# rocm_jax/jax/_src/cudnn/fused_attention_stablehlo.py
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import functools
import json
import math
import jax
from jax import dtypes
from jax._src import core
from jax._src import dispatch
from jax._src.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src.lib import cuda_versions
from jax._src import xla_bridge
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.interpreters.mlir import hlo
from jax.interpreters.mlir import ir
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec
Array = jnp.ndarray
class AttentionLayout(enum.Enum):
BTNH = 0
BNTH = 1
class MaskType(enum.Enum):
NO_MASK = 0
PADDING = 1
CAUSAL = 2
PADDING_CAUSAL = 3
ALIBI = 4
def convert_mask_type_to_string(mask_type: MaskType) -> str:
if mask_type == MaskType.NO_MASK:
return "NO_MASK"
elif mask_type == MaskType.PADDING:
return "PADDING"
elif mask_type == MaskType.CAUSAL:
return "CAUSAL"
elif mask_type == MaskType.PADDING_CAUSAL:
return "PADDING_CAUSAL"
elif mask_type == MaskType.ALIBI:
return "ALIBI"
else:
raise ValueError(f"Unexpected mask type: {mask_type}")
def has_padding(mask_type: MaskType) -> bool:
return mask_type == MaskType.PADDING or mask_type == MaskType.PADDING_CAUSAL
def should_export_dbias(bias_shape, query_shape, layout) -> bool:
b_B, b_N, _, _ = bias_shape
if layout == AttentionLayout.BNTH.value:
_, q_N, _, _ = query_shape
else:
_, _, q_N, _ = query_shape
return b_B == 1 and b_N == q_N
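# For example (illustrative): with a BTNH query of shape (4, 128, 8, 64), a
# broadcast bias of shape (1, 8, 128, 128) satisfies b_B == 1 and b_N == q_N,
# so a dbias gradient is exported; a per-batch bias of shape (4, 8, 128, 128)
# does not.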
def get_large_negative_number(dtype):
  # Temporary workaround: cuDNN has a bug when subtracting two large negative
  # values.
if dtype == jnp.bfloat16:
return jnp.asarray(-2 << 40, dtype=dtype)
elif dtype == jnp.float16:
return jnp.asarray(-2 << 14, dtype=dtype)
else:
raise ValueError("Unsupported dtype for inputs.")
def _normalize_layout(layout: str) -> AttentionLayout:
layout_upper = layout.upper()
if layout_upper in ["BSNH", "BNSH", "BTNH", "BNTH"]:
return AttentionLayout[layout_upper.replace("S", "T")]
else:
raise ValueError(f"Unsupported qkv_layout: {layout}")
def element_type_to_backend_config_type_mapping(dtype):
_element_type_to_backend_config_type_mapping = {
ir.BF16Type.get(): "BF16",
ir.F16Type.get(): "F16",
}
return _element_type_to_backend_config_type_mapping[dtype]
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]
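# For example (illustrative): default_layouts((4, 8, 16), (2,)) returns
# [range(2, -1, -1), range(0, -1, -1)], i.e. row-major minor-to-major orders
# (2, 1, 0) and (0,) for the two shapes.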
def create_dot_product_attention_backend_config(batch,
num_heads,
seq_q,
seq_kv,
dtype,
fmha_scale,
seed,
dropout_rate,
mask_type,
layout,
sliding_window_length,
is_bwd):
# Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H
# P: BMM1 output in shape of BNTS
# O: BMM2 output in the same shape with Q
# BMM1: Q @ K -> P
# BMM2: P @ V -> O
# BMM1Grad1: dP @ Q -> dK
# BMM1Grad2: dP @ K -> dQ
# BMM2Grad1: P @ dO -> dV
# BMM2Grad2: dO @ V -> dP
if sliding_window_length is None:
sliding_window_length = 0
cudnn_fmha_backend_config = {
"algorithm": {
"algo_id": "0",
"math_type": "TENSOR_OP_MATH",
"tuning_knobs": {"17": "1", "24": "0"},
"is_cudnn_frontend": True,
"workspace_size": "0",
},
"fmha_scale": fmha_scale,
"dropout_rate": dropout_rate,
"intermediate_tensor_shape": {
"element_type": element_type_to_backend_config_type_mapping(dtype),
"dimensions": [str(batch), str(num_heads), str(seq_q), str(seq_kv)],
"tuple_shapes": [],
"layout": {
"dim_level_types": [],
"dim_unique": [],
"dim_ordered": [],
"minor_to_major": ["3", "2", "1", "0"],
"tiles": [],
"element_size_in_bits": "0",
"memory_space": "0",
"index_primitive_type": "PRIMITIVE_TYPE_INVALID",
"pointer_primitive_type": "PRIMITIVE_TYPE_INVALID",
"dynamic_shape_metadata_prefix_bytes": "0",
},
"is_dynamic_dimension": [False, False, False, False],
},
"seed": seed,
"is_flash_attention": True,
"mask_type": convert_mask_type_to_string(mask_type),
"sliding_window_length": sliding_window_length,
}
# We define the contracting and batch dims in the format of
# ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
# rhs_batch_dims)).
if layout == AttentionLayout.BNTH.value:
dims = [
((3, 3), ((0, 1), (0, 1))), # BMM1: BNTH,BNSH->BNTS
((3, 2), ((0, 1), (0, 1))), # BMM2: BNTS,BNSH->BNTH
((2, 2), ((0, 1), (0, 1))), # BMM1_grad_1: BNTS,BNTH->BNSH
((3, 2), ((0, 1), (0, 1))), # BMM1_grad_2: BNTS,BNSH->BNTH
((2, 2), ((0, 1), (0, 1))), # BMM2_grad_1: BNTS,BNTH->BNSH
((3, 3), ((0, 1), (0, 1))), # BMM2_grad_2: BNTH,BNSH->BNTS
]
else:
dims = [
((3, 3), ((0, 2), (0, 2))), # BMM1: BTNH,BSNH->BNTS
((3, 1), ((0, 1), (0, 2))), # BMM2: BNTS,BSNH->BTNH
((2, 1), ((0, 1), (0, 2))), # BMM1_grad_1: BNTS,BTNH->BSNH
((3, 1), ((0, 1), (0, 2))), # BMM1_grad_2: BNTS,BSNH->BTNH
((2, 1), ((0, 1), (0, 2))), # BMM2_grad_1: BNTS,BTNH->BSNH
((3, 3), ((0, 2), (0, 2))), # BMM2_grad_2: BTNH,BSNH->BNTS
]
keys = [
"bmm1_dot_dimension_numbers",
"bmm2_dot_dimension_numbers",
"bmm1_grad_gemm1_dot_dimension_numbers",
"bmm1_grad_gemm2_dot_dimension_numbers",
"bmm2_grad_gemm1_dot_dimension_numbers",
"bmm2_grad_gemm2_dot_dimension_numbers",
]
fwd_dot_number = {}
bwd_dot_number = {}
for idx, (key, ((lc, rc), (lb, rb))) in enumerate(zip(keys, dims)):
dims_to_write = fwd_dot_number if idx < 2 else bwd_dot_number
dims_to_write[key] = {
"lhs_contracting_dimensions": [str(lc)],
"rhs_contracting_dimensions": [str(rc)],
"lhs_batch_dimensions": [str(i) for i in lb],
"rhs_batch_dimensions": [str(i) for i in rb],
}
if is_bwd:
cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **bwd_dot_number}
else:
cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **fwd_dot_number}
backend_config = {
"operation_queue_id":"0",
"wait_on_operation_queues":[],
"cudnn_fmha_backend_config": cudnn_fmha_backend_config
}
backend_config = json.dumps(backend_config)
return backend_config
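# A rough sketch (abbreviated, with illustrative values) of the JSON produced
# above:
#   {
#     "operation_queue_id": "0",
#     "wait_on_operation_queues": [],
#     "cudnn_fmha_backend_config": {
#       "algorithm": {...},
#       "fmha_scale": 0.125,
#       "dropout_rate": 0.0,
#       "intermediate_tensor_shape": {...},
#       "mask_type": "CAUSAL",
#       "sliding_window_length": 0,
#       "bmm1_dot_dimension_numbers": {...},  # fwd config; bwd gets the grad GEMMs
#       ...
#     }
#   }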
# mapping from (is_bwd, has_dropout, has_bias) to custom call name
_custom_name_maps = {
# fMHA forward call targets.
(False, False, False): "__cudnn$fmhaSoftmax",
(False, False, True): "__cudnn$fmhaScaleBiasSoftmax",
(False, True, False): "__cudnn$fmhaSoftmaxDropout",
(False, True, True): "__cudnn$fmhaScaleBiasSoftmaxDropout",
# fMHA backward call targets.
(True, False, False): "__cudnn$fmhaSoftmaxBackward",
(True, False, True): "__cudnn$fmhaScaleBiasSoftmaxBackward",
(True, True, False): "__cudnn$fmhaSoftmaxDropoutBackward",
(True, True, True): "__cudnn$fmhaScaleBiasSoftmaxDropoutBackward",
}
def get_custom_call_name(has_bias, has_dropout, is_bwd):
return _custom_name_maps[(is_bwd, has_dropout, has_bias)]
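# For example (illustrative): get_custom_call_name(has_bias=True,
# has_dropout=False, is_bwd=False) resolves to "__cudnn$fmhaScaleBiasSoftmax",
# while the same configuration with is_bwd=True resolves to
# "__cudnn$fmhaScaleBiasSoftmaxBackward".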
def check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout):
def check_eq(a, b, c, msg):
if not (a == b == c):
raise ValueError(f"{msg} must be same, got {a}, {b}, {b}")
q_rank, k_rank, v_rank = len(query.shape), len(key.shape), len(value.shape)
if q_rank != 4:
raise ValueError(f"Q must have a rank of 4, got {q_rank}")
check_eq(q_rank, k_rank, v_rank, "QKV rank")
q_dtype, k_dtype, v_dtype = query.dtype, key.dtype, value.dtype
if q_dtype not in [jnp.bfloat16, jnp.float16]:
raise NotImplementedError(f"Q must be fp16 or bf16, got {q_dtype}")
check_eq(q_dtype, k_dtype, v_dtype, "QKV dtype")
if layout == AttentionLayout.BNTH:
qB, qN, qT, qH = query.shape
kB, kN, kS, kH = key.shape
vB, vN, vS, vH = value.shape
else:
assert layout == AttentionLayout.BTNH
qB, qT, qN, qH = query.shape
kB, kS, kN, kH = key.shape
vB, vS, vN, vH = value.shape
check_eq(qB, kB, vB, "QKV batch")
check_eq(qH, kH, vH, "QKV dim_per_head")
if kN != vN:
raise ValueError(f"KV must have same number of heads, got {kN} vs {vN}")
if kS != vS:
raise ValueError(f"KV must have same seq length, got {kS} vs {vS}")
# check bias/q_seqlen/kv_seqlen
if bias is not None:
_, _, bT, bS = bias.shape
if bT != qT or bS != vS:
raise ValueError(
f"Bias must have same seq length as QKV, got {bT} and {bS}")
if q_seqlen is not None:
q_seq_dtype = q_seqlen.dtype
q_seq_rank = len(q_seqlen.shape)
if q_seq_dtype != jnp.int32:
raise ValueError(f"q_seqlen must have int32 datatype, got {q_seq_dtype}")
if q_seq_rank != 1:
raise ValueError(f"q_seqlen must have a rank of 1, got {q_seq_rank}")
q_seq_b = q_seqlen.shape[0]
if q_seq_b != qB:
raise ValueError(f"q_seqlen must have same batch as Q, got {q_seq_b}")
if kv_seqlen is not None:
kv_seq_dtype = kv_seqlen.dtype
kv_seq_rank = len(kv_seqlen.shape)
if kv_seq_dtype != jnp.int32:
raise ValueError(
f"kv_seqlen must have int32 datatype, got {kv_seq_dtype}")
if kv_seq_rank != 1:
raise ValueError(f"kv_seq_rank must have a rank of 1, got {kv_seq_rank}")
kv_seq_b = kv_seqlen.shape[0]
if kv_seq_b != qB:
raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}")
def check_is_flash_attention(
query, key, layout: int, cudnn_version, has_bias, is_training):
if layout == AttentionLayout.BNTH.value:
_, _, T, H = query.shape
_, _, S, _ = key.shape
else:
_, T, _, H = query.shape
_, S, _, _ = key.shape
if (H > 128 or H % 8 != 0 or
(is_training and has_bias and (T % 2 != 0 or S % 2 != 0))):
# check if flash attention is supported
# for training, for patterns with bias, seqlen should be divisible by 2
raise NotImplementedError(
f"Unsupported sequence length Q {T}, KV {S} and head dim {H}.")
# check if minimum cudnn version requirement is satisfied
if cudnn_version < 8904:
    raise RuntimeError(
        "JAX requires cuDNN >= 8.9.4 to use flash attention.")
def check_cudnn_version():
# check if cuDNN is installed
if cuda_versions is None:
raise RuntimeError("cuDNN is not detected.")
return cuda_versions.cudnn_get_version()
def check_compute_capability(capability):
  if 'cuda' not in xla_bridge.get_backend().platform_version:
return False
d, *_ = jax.local_devices(backend="gpu")
target = tuple(int(x) for x in capability.split("."))
current = tuple(int(x) for x in d.compute_capability.split("."))
return current >= target
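# For example (illustrative): check_compute_capability("8.0") is True on an
# A100 (sm_80) or newer GPU, and False on older CUDA devices or when the
# active backend is not CUDA.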
def _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, False)
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=False)
output = outputs[0]
return output
def _dot_product_attention_fwd_rule(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, cudnn_version):
# check if flash attention is supported for this attention pattern
check_is_flash_attention(
query, key, layout, cudnn_version, bias is not None, True)
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=True)
res = (query, key, value, bias, q_seqlen, kv_seqlen,
outputs[1], outputs[0])
return outputs[0], res
def _dot_product_attention_bwd_rule(
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training, res, grad_output):
(query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output) = res
grads = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length
)
grads = (*grads,) + (None,) * (6 - len(grads))
return grads
def _dot_product_attention_fwd_impl(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
# args: {Q, K, V, mask*, bias*}
outputs = _dot_product_attention_fwd_p.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=is_training)
return outputs
def _dot_product_attention_bwd_impl(
query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length):
grads = _dot_product_attention_bwd_p.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length)
return grads
def _dot_product_attention_fwd_abstract(
query, key, value, bias, q_seqlen, kv_seqlen, *, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
if layout == AttentionLayout.BNTH.value:
B, N, T, _ = query.shape
_, _, S, _ = key.shape
else:
B, T, N, _ = query.shape
_, S, _, _ = key.shape
output_shape = query.shape
softmax_stat_shape = (B, N, T)
if is_training:
return (
core.ShapedArray(output_shape, query_dtype), # output
core.ShapedArray(softmax_stat_shape, jnp.float32), # softmax_stat
)
else:
return (
core.ShapedArray(output_shape, query_dtype), # output
)
def _dot_product_attention_bwd_abstract(
query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output,
grad_output, *, scale, seed, dropout_rate, variadic_args, mask_type,
layout, sliding_window_length):
query_dtype = dtypes.canonicalize_dtype(query.dtype)
key_dtype = dtypes.canonicalize_dtype(key.dtype)
value_dtype = dtypes.canonicalize_dtype(value.dtype)
_, has_dbias = variadic_args
if has_dbias:
# cuDNN supports bias for this case
bias_dtype = dtypes.canonicalize_dtype(bias.dtype)
return (
core.ShapedArray(
query.shape, query_dtype
), # grad query
core.ShapedArray(
key.shape, key_dtype
), # grad key
core.ShapedArray(
value.shape, value_dtype
), # grad value
core.ShapedArray(
bias.shape, bias_dtype
), # grad bias
)
else:
return (
core.ShapedArray(
query.shape, query_dtype
), # grad query
core.ShapedArray(
key.shape, key_dtype
), # grad key
core.ShapedArray(
value.shape, value_dtype
), # grad value
)
def _dot_product_attention_fwd_cuda_lowering(
ctx, query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, is_training):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
key_shape = key_type.shape
if layout == AttentionLayout.BNTH.value:
B, N, T, H = query_shape
_, _, S, _ = key_shape
output_layout = (3, 2, 1, 0)
output_transpose_perm = mlir.dense_int_array((0, 1, 2, 3))
else:
B, T, N, H = query_shape
_, S, _, _ = key_shape
output_layout = (3, 1, 2, 0)
output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
output_shape = (B, N, T, H)
softmax_stat_shape = (B, N, T)
workspace_shape = (0,)
workspace_type = ir.IntegerType.get_unsigned(8)
backend_config = create_dot_product_attention_backend_config(
B, N, T, S, query_type.element_type, scale, seed, dropout_rate,
mask_type, layout, sliding_window_length, is_bwd=False,
)
# {Q, K, V, bias*, q_seqlen*, kv_seqlen*}
# {output, activation*, workspace}
has_dropout = dropout_rate > 0
has_bias, _ = variadic_args
operands = [query, key, value]
if has_bias:
operands.append(bias)
if has_padding(mask_type):
operands.append(q_seqlen)
operands.append(kv_seqlen)
custom_call_name = get_custom_call_name(has_bias, has_dropout, False)
# create output types and layouts
if is_training:
result_types = [
ir.RankedTensorType.get(output_shape, query_type.element_type),
ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()),
ir.RankedTensorType.get(workspace_shape, workspace_type),
]
result_layouts = [output_layout] + default_layouts(softmax_stat_shape, workspace_shape)
else:
result_types = [
ir.RankedTensorType.get(output_shape, query_type.element_type),
ir.RankedTensorType.get(workspace_shape, workspace_type)
]
result_layouts = [output_layout] + default_layouts(workspace_shape)
# create custom call here
out = mlir.custom_call(
custom_call_name,
result_types=result_types,
operands=operands,
backend_config=backend_config,
operand_layouts=default_layouts(
*[ir.RankedTensorType(operand.type).shape for operand in operands]),
result_layouts=result_layouts,
)
# drop workspace memory
# output should be (B, T, N, H) instead of (B, N, T, H)
if is_training:
return [hlo.transpose(out.results[0], output_transpose_perm), out.results[1]]
else:
return [hlo.transpose(out.results[0], output_transpose_perm)]
def _dot_product_attention_bwd_cuda_lowering(
ctx, query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length):
query_type = ir.RankedTensorType(query.type)
query_shape = query_type.shape
key_type = ir.RankedTensorType(key.type)
key_shape = key_type.shape
value_type = ir.RankedTensorType(value.type)
if layout == AttentionLayout.BNTH.value:
B, q_N, T, H = query_shape
_, k_N, S, _ = key_shape
grad_layout = (3, 2, 1, 0)
grad_transpose_perm = mlir.dense_int_array((0, 1, 2, 3))
else:
B, T, q_N, H = query_shape
_, S, k_N, _ = key_shape
grad_layout = (3, 1, 2, 0)
grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
workspace_shape = (0,)
workspace_type = ir.IntegerType.get_unsigned(8)
grad_query_shape = (B, q_N, T, H)
grad_key_shape = (B, k_N, S, H)
grad_value_shape = (B, k_N, S, H)
backend_config = create_dot_product_attention_backend_config(
B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate,
mask_type, layout, sliding_window_length, is_bwd=True,
)
# {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*}
# {dQ, dK, dV, dbias*, workspace}
has_dropout = dropout_rate > 0
has_bias, has_dbias = variadic_args
# create operands
operands = [query, key, value, activation, grad_output]
if has_bias:
# flash attention requires bias in the bwd for remat
operands.append(bias)
operands.append(fwd_output)
if has_padding(mask_type):
operands.append(q_seqlen)
operands.append(kv_seqlen)
# get custom call name
custom_call_name = get_custom_call_name(has_bias, has_dropout, True)
# create output types and layouts
# grad_query, grad_key, grad_value
result_types = [
ir.RankedTensorType.get(grad_query_shape, query_type.element_type),
ir.RankedTensorType.get(grad_key_shape, key_type.element_type),
ir.RankedTensorType.get(grad_value_shape, value_type.element_type),
]
result_layouts = [grad_layout, grad_layout, grad_layout]
bias_type = ir.RankedTensorType(bias.type)
bias_shape = bias_type.shape
if has_dbias:
# cuDNN supports bias for this case
result_types.append(
ir.RankedTensorType.get(bias_shape, bias_type.element_type))
result_layouts = result_layouts + default_layouts(bias_shape)
# workspace
result_types.append(ir.RankedTensorType.get(workspace_shape, workspace_type))
result_layouts = result_layouts + default_layouts(workspace_shape)
out = mlir.custom_call(
custom_call_name,
result_types=result_types,
operands=operands,
backend_config=backend_config,
operand_layouts=default_layouts(
*[ir.RankedTensorType(operand.type).shape for operand in operands]),
result_layouts=result_layouts,
)
dqkv = (hlo.transpose(out.results[0], grad_transpose_perm),
hlo.transpose(out.results[1], grad_transpose_perm),
hlo.transpose(out.results[2], grad_transpose_perm))
# Only keep dQ, dK, dV and dBias here
if has_dbias:
return dqkv + (out.results[3],)
else:
return dqkv
# batcher
def _check_valid_batch_dims(bdims):
for dim in bdims:
if dim not in [0, None]:
raise NotImplementedError(
f"Currently only support batch_dim in [0, None], but got {dim=}")
def _dot_product_attention_fwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length, is_training):
_check_valid_batch_dims(batch_dims)
query, key, value, bias, q_seqlen, kv_seqlen = batched_args
query_bdim = batch_dims[0]
if is_training:
out_bdims = query_bdim, query_bdim
else:
out_bdims = (query_bdim,)
if layout == AttentionLayout.BNTH.value:
*Bs, N, T, _ = query.shape
*_, _, S, _ = key.shape
else:
*Bs, T, N, _ = query.shape
*_, S, _, _ = key.shape
B = math.prod(Bs)
has_bias, _ = variadic_args
original_shape = query.shape
# reshape to 4D shape
query = jnp.reshape(query, (B,) + query.shape[-3:])
key = jnp.reshape(key, (B,) + key.shape[-3:])
  value = jnp.reshape(value, (B,) + value.shape[-3:])
if has_bias and batch_dims[3] is not None:
bias = jnp.reshape(bias, (B, N, T, S))
if has_padding(mask_type):
q_seqlen = jnp.reshape(q_seqlen, (B, ))
kv_seqlen = jnp.reshape(kv_seqlen, (B, ))
outputs = _dot_product_attention_fwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length, is_training=is_training)
# reshape to original shape
output = outputs[0]
output = jnp.reshape(output, original_shape)
if is_training:
activation = outputs[1]
activation = jnp.reshape(activation, (*Bs, N, T))
return (output, activation), out_bdims
else:
return (output,), out_bdims
def _dot_product_attention_bwd_batcher(
batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args,
mask_type, layout, sliding_window_length):
_check_valid_batch_dims(batch_dims)
  (query, key, value, bias, q_seqlen,
   kv_seqlen, activation, fwd_output, grad_output) = batched_args
query_bdim = batch_dims[0]
out_bdims = query_bdim, query_bdim, query_bdim
if layout == AttentionLayout.BNTH.value:
*Bs, N, T, _ = query.shape
*_, _, S, _ = key.shape
else:
*Bs, T, N, _ = query.shape
*_, S, _, _ = key.shape
B = math.prod(Bs)
has_bias, has_dbias = variadic_args
# Reset the has_dbias if the combined batch size is not 1, because cuDNN only
# supports dbias with a single batch. In this case, an all-zero dbias will be
# appended instead.
if B > 1:
variadic_args = (has_bias, False)
original_query_shape = query.shape
original_key_shape = key.shape
original_value_shape = value.shape
original_bias_shape = bias.shape if has_bias else None
# reshape to 4D shape
query = jnp.reshape(query, (B,) + query.shape[-3:])
key = jnp.reshape(key, (B,) + key.shape[-3:])
  value = jnp.reshape(value, (B,) + value.shape[-3:])
if has_bias and batch_dims[3] is not None:
bias = jnp.reshape(bias, (B, N, T, S))
if has_padding(mask_type):
q_seqlen = jnp.reshape(q_seqlen, (B, ))
kv_seqlen = jnp.reshape(kv_seqlen, (B, ))
activation = jnp.reshape(activation, (B, N, T))
fwd_output = jnp.reshape(fwd_output, (B,) + query.shape[-3:])
grad_output = jnp.reshape(grad_output, (B,) + query.shape[-3:])
grads = _dot_product_attention_bwd_p_wrapper.bind(
query, key, value, bias, q_seqlen, kv_seqlen, activation,
fwd_output, grad_output, scale=scale, seed=seed,
dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout,
sliding_window_length=sliding_window_length,
)
# reshape to original shape
grads[0] = jnp.reshape(grads[0], original_query_shape)
grads[1] = jnp.reshape(grads[1], original_key_shape)
grads[2] = jnp.reshape(grads[2], original_value_shape)
if has_dbias:
assert has_bias
if variadic_args[1]:
grads[3] = jnp.reshape(grads[3], original_bias_shape)
else:
grads.append(jnp.zeros(original_bias_shape, bias.dtype))
out_bdims += (batch_dims[3],)
return grads, out_bdims
# custom partitioning
def _get_padded_spec(arg_info):
spec = None if arg_info.sharding is None else arg_info.sharding.spec
ndim = arg_info.ndim
if spec is None:
return (None,) * ndim
assert len(spec) <= ndim
return spec + (None,) * (ndim - len(spec))
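# For example (illustrative): for a rank-4 operand sharded with
# PartitionSpec("data"), the padded spec is ("data", None, None, None); an
# unsharded operand yields (None, None, None, None).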
def _check_qkv_bias_mask_spec(
query_spec, key_spec, value_spec, bias_spec):
# check qkv spec
if not query_spec == key_spec == value_spec:
raise ValueError("Query, key and value should have same sharding.")
*batch_spec, q_seq_spec, num_head_spec, head_spec = query_spec
if q_seq_spec is not None:
raise ValueError("Sharding on sequence dim is not allowed.")
if head_spec is not None:
raise ValueError("Sharding on head dim is not allowed.")
# check bias spec
if bias_spec:
*bias_batch_spec, bias_num_head_spec, bias_q_seq_spec, bias_kv_seq_spec = bias_spec
    if ((any(bias_batch_spec) and bias_batch_spec != batch_spec) or
        (bias_num_head_spec is not None and
         bias_num_head_spec != num_head_spec)):
raise ValueError(
"Query and bias should have same sharding on batch and num_head dim.")
if bias_q_seq_spec is not None or bias_kv_seq_spec is not None:
raise ValueError("Sharding on bias sequence dim is not allowed.")
# fwd custom partition
def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training):
# only sharding on batch and num_head dim is allowed
# (*batch, q_seq, num_head, head)
query_spec = _get_padded_spec(arg_shapes[0])
# (*batch, kv_seq, num_head, head)
key_spec = _get_padded_spec(arg_shapes[1])
value_spec = _get_padded_spec(arg_shapes[2])
has_bias, _ = variadic_args
bias_spec = _get_padded_spec(arg_shapes[3]) if has_bias else None
_check_qkv_bias_mask_spec(
query_spec, key_spec, value_spec, bias_spec)
# keep out sharding same as query sharding since they have same shape
out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
if is_training:
# activation sharding
*batch_spec, q_seq_spec, num_head_spec, _ = query_spec
activation_sharding = NamedSharding(
mesh, PartitionSpec(*batch_spec, num_head_spec, q_seq_spec, None))
return [out_sharding, activation_sharding]
return [out_sharding]
_dot_product_attention_fwd_lower = custom_partitioning(
_dot_product_attention_fwd_impl, static_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
def _dot_product_attention_fwd_infer_sharding_from_operands(
scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length,
is_training, mesh, arg_shapes, result_shape):
return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training)
def _dot_product_attention_fwd_partition(
scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length,
is_training, mesh, arg_shapes, result_shape):
# args sharding
arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes)
out_shardings = _infer_fwd_output_sharding(
mesh, arg_shapes, variadic_args, is_training)
impl = functools.partial(
_dot_product_attention_fwd_impl,
scale=scale,
seed=seed,
dropout_rate=dropout_rate,
variadic_args=variadic_args,
mask_type=mask_type,
layout=layout,
sliding_window_length=sliding_window_length,
is_training=is_training,
)
return mesh, impl, out_shardings, arg_shardings
# bwd custom partition
def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args):
# (*batch, q_seq, num_head, head)
query_spec = _get_padded_spec(arg_shapes[0])
# (*batch, kv_seq, num_head, head)
key_spec = _get_padded_spec(arg_shapes[1])
value_spec = _get_padded_spec(arg_shapes[2])
has_bias, has_dbias = variadic_args
bias_spec = _get_padded_spec(arg_shapes[3]) if has_bias else None
_check_qkv_bias_mask_spec(
query_spec, key_spec, value_spec, bias_spec)
# keep grad query sharding same as query sharding
grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
out_shardings = [grad_query_sharding, grad_key_sharding, grad_value_sharding]
if has_dbias:
grad_bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec))
out_shardings = out_shardings + [grad_bias_sharding]
return out_shardings
_dot_product_attention_bwd_lower = custom_partitioning(
_dot_product_attention_bwd_impl, static_argnums=(9, 10, 11, 12, 13, 14, 15)
)
def _dot_product_attention_bwd_infer_sharding_from_operands(
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, mesh, arg_shapes, result_shape):
return _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)
def _dot_product_attention_bwd_partition(
scale, seed, dropout_rate, variadic_args, mask_type, layout,
sliding_window_length, mesh, arg_shapes, result_shape):
out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args)
# args sharding
arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes)
def sharded_impl(*args):
impl = functools.partial(
_dot_product_attention_bwd_impl,
scale=scale,
seed=seed,
dropout_rate=dropout_rate,
variadic_args=variadic_args,
mask_type=mask_type,
layout=layout,
sliding_window_length=sliding_window_length,
)
grads = impl(*args)
_, has_dbias = variadic_args
if has_dbias:
query_spec = arg_shardings[0].spec
batch_spec = query_spec[0]
local_dbias = grads[3]
global_dbias = jax.lax.psum(local_dbias, batch_spec)
grads = grads[:3] + [global_dbias]
return grads
return mesh, sharded_impl, out_shardings, arg_shardings
# Create dot_product_attention_fwd_p for forward operation.
_dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd")
_dot_product_attention_fwd_p.multiple_results = True
_dot_product_attention_fwd_p.def_impl(
functools.partial(xla.apply_primitive, _dot_product_attention_fwd_p)
)
_dot_product_attention_fwd_p.def_abstract_eval(
_dot_product_attention_fwd_abstract
)
mlir.register_lowering(
_dot_product_attention_fwd_p,
_dot_product_attention_fwd_cuda_lowering,
platform="cuda",
)
_dot_product_attention_fwd_p_wrapper = core.Primitive(
"dot_product_attention_fwd_wrapper"
)
_dot_product_attention_fwd_p_wrapper.multiple_results = True
_dot_product_attention_fwd_p_wrapper.def_impl(_dot_product_attention_fwd_impl)
_dot_product_attention_fwd_p_wrapper.def_abstract_eval(
_dot_product_attention_fwd_abstract
)
# Create dot_product_attention_bwd_p for backward operation.
_dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd")
_dot_product_attention_bwd_p.multiple_results = True
_dot_product_attention_bwd_p.def_impl(
functools.partial(xla.apply_primitive, _dot_product_attention_bwd_p)
)
_dot_product_attention_bwd_p.def_abstract_eval(
_dot_product_attention_bwd_abstract
)
mlir.register_lowering(
_dot_product_attention_bwd_p,
_dot_product_attention_bwd_cuda_lowering,
platform="cuda",
)
_dot_product_attention_bwd_p_wrapper = core.Primitive(
"dot_product_attention_bwd_wrapper"
)
_dot_product_attention_bwd_p_wrapper.multiple_results = True
_dot_product_attention_bwd_p_wrapper.def_impl(_dot_product_attention_bwd_impl)
_dot_product_attention_bwd_p_wrapper.def_abstract_eval(
_dot_product_attention_bwd_abstract
)
batching.primitive_batchers[
_dot_product_attention_fwd_p_wrapper
] = _dot_product_attention_fwd_batcher
batching.primitive_batchers[
_dot_product_attention_bwd_p_wrapper
] = _dot_product_attention_bwd_batcher
_dot_product_attention_fwd_lower.def_partition(
infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands,
partition=_dot_product_attention_fwd_partition)
mlir.register_lowering(_dot_product_attention_fwd_p_wrapper,
mlir.lower_fun(_dot_product_attention_fwd_lower, multiple_results=True))
_dot_product_attention_bwd_lower.def_partition(
infer_sharding_from_operands=_dot_product_attention_bwd_infer_sharding_from_operands,
partition=_dot_product_attention_bwd_partition)
mlir.register_lowering(_dot_product_attention_bwd_p_wrapper,
mlir.lower_fun(_dot_product_attention_bwd_lower, multiple_results=True))
dispatch.prim_requires_devices_during_lowering.add(
_dot_product_attention_fwd_p
)
dispatch.prim_requires_devices_during_lowering.add(
_dot_product_attention_fwd_p_wrapper
)
dispatch.prim_requires_devices_during_lowering.add(
_dot_product_attention_bwd_p
)
dispatch.prim_requires_devices_during_lowering.add(
_dot_product_attention_bwd_p_wrapper
)
@functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12, 13))
def _dot_product_attention(query: Array,
key: Array,
value: Array,
bias: Array,
q_seqlen: Array,
kv_seqlen: Array,
scale: float,
seed: int,
dropout_rate: float,
variadic_args: tuple[bool, ...],
                           mask_type: MaskType,
layout: int,
sliding_window_length: int | None,
cudnn_version: int):
output = _dot_product_attention_fwd(
query, key, value, bias, q_seqlen, kv_seqlen, scale=scale,
seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length,
cudnn_version=cudnn_version)
return output
# _dot_product_attention_fwd_rule must have the same signature as
# _dot_product_attention.
_dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_attention_bwd_rule)
# User interface
def dot_product_attention(query: Array,
key: Array,
value: Array,
bias: Array | None = None,
mask: Array | None = None,
q_seqlen: Array | None = None,
kv_seqlen: Array | None = None,
*,
scale: float = 1.0,
mask_type: MaskType = MaskType.NO_MASK,
seed: int = 42,
dropout_rate: float = 0.,
qkv_layout: str = "BTNH",
sliding_window_length: int | None = None):
"""Computes dot-product attention given query (Q), key (K), and value (V).
This function serves as the core operation for applying attention
mechanisms as described in the paper [https://arxiv.org/abs/1706.03762].
  Initially, it determines the attention weights by processing Q and K,
  subsequently combining the outcomes using V. Throughout this function, we
  use the following uppercase letters to represent specific dimensions of the
  arrays:
B = batch size
S = length of the key/value (source)
T = length of the query (target)
N = number of attention heads
H = dimensions of each attention head.
The supported layouts for Q, K, V are either BT(S)NH or BNT(S)H, and they must
adhere to the same layout. The output layout remains consistent with Q,
defaulting to BT(S)NH.
Args:
query: Queries for attention calculation with a shape of BTNH or BNTH.
key: Keys for attention calculation with a shape of BSNH or BNSH.
value: Values to be used in attention with a shape of BSNH or BNSH.
bias: Bias to be added to logits with a shape of BNTS.
mask: Mask used to filter out logits with a shape of BNTS.
    q_seqlen: Non-padded sequence length of the queries, with a shape of B.
    kv_seqlen: Non-padded sequence length of the keys and values, with a shape
      of B.
    scale: Scale for the query.
    mask_type: Type of mask for cuDNN to generate internally; one of NO_MASK,
      PADDING, CAUSAL, PADDING_CAUSAL or ALIBI.
    seed: Seed used for the dropout pseudo-random number generator.
    dropout_rate: Dropout rate.
qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH,
BNSH.
sliding_window_length: Window size to make attention only attend to each
token's left local window (pos - sliding_window_length, pos] where `pos`
      is the index of each token. E.g., if sliding_window_length == 3 and the
      sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [2, 3, c].
Returns:
Output of the same shape as the query.
"""
# TODO(b/380898464): Check the compute capability, e.g., require GPU device,
# in the kernel implementation (c++) code.
cudnn_version = check_cudnn_version()
layout = _normalize_layout(qkv_layout)
if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None):
raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask")
if sliding_window_length is not None and sliding_window_length <= 0:
raise ValueError(
f"Require sliding_window_length > 0, got {sliding_window_length}")
if bias is not None:
# reshape bias to have 4D shape
bias = bias.reshape((1,) * (4 - len(bias.shape)) + bias.shape)
if mask is not None:
if mask.dtype == jnp.bool:
large_negative_number = get_large_negative_number(query.dtype)
mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number)
# reshape mask to have 4D shape
mask = mask.reshape((1,) * (4 - len(mask.shape)) + mask.shape) # type: ignore[union-attr]
  # combine bias and mask
  if bias is None:
    bias = mask
  elif mask is not None:
    # bias and mask are broadcast against each other to a common shape
    bias = bias + mask
  # check that the input shapes and data types are compatible
check_layout(query, key, value, bias, q_seqlen, kv_seqlen, layout)
has_bias = bias is not None
has_dbias = has_bias and \
should_export_dbias(bias.shape, query.shape, layout) # type: ignore[union-attr]
variadic_args = (has_bias, has_dbias)
if bias is None:
bias = jnp.zeros(0, dtype=query.dtype)
if q_seqlen is None:
q_seqlen = jnp.zeros(0, dtype=query.dtype)
if kv_seqlen is None:
kv_seqlen = jnp.zeros(0, dtype=query.dtype)
output = _dot_product_attention(
query, key, value, bias, q_seqlen, kv_seqlen, scale, seed,
dropout_rate, variadic_args, mask_type, layout.value, sliding_window_length,
cudnn_version)
return output
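# Example usage (a minimal sketch; the shapes, dtype, and causal-mask choice
# below are illustrative, and running it requires a CUDA GPU with a recent
# cuDNN):
#
#   B, T, S, N, H = 2, 128, 128, 8, 64
#   q = jnp.zeros((B, T, N, H), dtype=jnp.bfloat16)
#   k = jnp.zeros((B, S, N, H), dtype=jnp.bfloat16)
#   v = jnp.zeros((B, S, N, H), dtype=jnp.bfloat16)
#   out = dot_product_attention(q, k, v, scale=1.0 / math.sqrt(H),
#                               mask_type=MaskType.CAUSAL, qkv_layout="BTNH")
#   # `out` has shape (B, T, N, H), matching the query layout.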