mirror of https://github.com/ROCm/jax.git

commit 5708fb955b (parent 6957d26dd3)
address some format issues

@@ -1,672 +0,0 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial, reduce
import operator
from typing import Any, Optional
import json

import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call
from jax._src.lib.mlir.dialects import hlo
from jax._src.core import ShapedArray

from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from jax._src.interpreters import batching
from jax._src import dispatch
from jax._src.lib import cuda_versions

Array = jnp.ndarray
DType = jnp.dtype
PRNGKey = jnp.ndarray

def element_type_to_backend_config_type_mapping(dtype):
  _element_type_to_backend_config_type_mapping = {
    ir.BF16Type.get(): "BF16",
    ir.F16Type.get(): "F16",
  }
  return _element_type_to_backend_config_type_mapping.get(dtype)

def default_layouts(*shapes):
  return [range(len(shape) - 1, -1, -1) for shape in shapes]

def create_dot_product_attention_backend_config(batch,
                                                num_heads,
                                                seq_q,
                                                seq_kv,
                                                dtype,
                                                fmha_scale,
                                                seed,
                                                dropout_rate,
                                                is_flash_attention,
                                                is_causal_mask,
                                                is_bwd):
  # b q_seq num_heads head_dim -> Q
  # b kv_seq num_heads head_dim -> K
  # b kv_seq num_heads head_dim -> V
  # b num_heads q_seq kv_seq -> P
  # b q_seq num_heads head_dim -> O
  # bmm1: Q @ K -> P
  # bmm2: P @ V -> O
  # bmm2Grad1: P @ dO -> dV
  # bmm2Grad2: dO @ V -> dP
  # bmm1Grad1: dP @ Q -> dK
  # bmm1Grad2: dP @ K -> dQ
  backend_config = {
    "algorithm":{"algo_id":"0","math_type":"TENSOR_OP_MATH","tuning_knobs":{"17":"1","24":"0"},"is_cudnn_frontend":True,"workspace_size":"0"},
    "fmha_scale":fmha_scale,
    "dropout_rate":dropout_rate,
    "intermediate_tensor_shape":{"element_type":element_type_to_backend_config_type_mapping(dtype),"dimensions":[str(batch),str(num_heads),str(seq_q),str(seq_kv)],"tuple_shapes":[],"layout":{"dim_level_types":[],"dim_unique":[],"dim_ordered":[],"minor_to_major":["3","2","1","0"],"tiles":[],"element_size_in_bits":"0","memory_space":"0","index_primitive_type":"PRIMITIVE_TYPE_INVALID","pointer_primitive_type":"PRIMITIVE_TYPE_INVALID","dynamic_shape_metadata_prefix_bytes":"0"},"is_dynamic_dimension":[False,False,False,False]},
    "seed":seed,
    "is_flash_attention":is_flash_attention,
    "is_causal_mask":is_causal_mask
  }
  fwd_dot_number = {
    "bmm1_dot_dimension_numbers":{"lhs_contracting_dimensions":["3"],"rhs_contracting_dimensions":["3"],"lhs_batch_dimensions":["0","2"],"rhs_batch_dimensions":["0","2"]},
    "bmm2_dot_dimension_numbers":{"lhs_contracting_dimensions":["3"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0","1"],"rhs_batch_dimensions":["0","2"]},
  }
  bwd_dot_number = {
    "bmm1_grad_gemm1_dot_dimension_numbers":{"lhs_contracting_dimensions":["2"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0","1"],"rhs_batch_dimensions":["0","2"]},
    "bmm1_grad_gemm2_dot_dimension_numbers":{"lhs_contracting_dimensions":["3"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0","1"],"rhs_batch_dimensions":["0","2"]},
    "bmm2_grad_gemm1_dot_dimension_numbers":{"lhs_contracting_dimensions":["2"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0","1"],"rhs_batch_dimensions":["0","2"]},
    "bmm2_grad_gemm2_dot_dimension_numbers":{"lhs_contracting_dimensions":["3"],"rhs_contracting_dimensions":["3"],"lhs_batch_dimensions":["0","2"],"rhs_batch_dimensions":["0","2"]},
  }
  if is_bwd:
    backend_config = {**backend_config, **bwd_dot_number}
  else:
    backend_config = {**backend_config, **fwd_dot_number}

  backend_config = json.dumps(backend_config)
  return backend_config

def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
  index = is_bwd << 3 | has_dropout << 2 | has_mask << 1 | has_bias
  _custom_name_maps = [
    # fMHA forward call targets.
    "__cudnn$fhmaSoftmax",
    "__cudnn$fhmaScaleBiasSoftmax",
    "__cudnn$fhmaScaleMaskSoftmax",
    "__cudnn$fhmaScaleBiasMaskSoftmax",
    "__cudnn$fhmaSoftmaxDropout",
    "__cudnn$fhmaScaleBiasSoftmaxDropout",
    "__cudnn$fhmaScaleMaskSoftmaxDropout",
    "__cudnn$fhmaScaleBiasMaskSoftmaxDropout",
    # fMHA backward call targets.
    "__cudnn$fhmaSoftmaxBackward",
    "__cudnn$fhmaScaleBiasSoftmaxBackward",
    "__cudnn$fhmaScaleMaskSoftmaxBackward",
    "__cudnn$fhmaScaleBiasMaskSoftmaxBackward",
    "__cudnn$fhmaSoftmaxDropoutBackward",
    "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward",
    "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward",
    "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward"
  ]
  return _custom_name_maps[index]

def check_qkv_layout(query, key, value):
  assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \
    "query, key and value should have rank 4."

  # Only support fp16 and bf16 here
  query_dtype = query.dtype
  key_dtype = key.dtype
  value_dtype = value.dtype
  assert query_dtype == key_dtype == value_dtype and query_dtype in [jnp.float16, jnp.bfloat16], \
    "query, key and value should have same dtype and should be float16 or bfloat16"

  q_batch, q_seq_len, q_num_heads, q_head_dim = query.shape
  k_batch, k_seq_len, k_num_heads, k_head_dim = key.shape
  v_batch, v_seq_len, v_num_heads, v_head_dim = value.shape
  assert (q_batch == k_batch == v_batch) \
    and (k_seq_len == v_seq_len) \
    and (q_num_heads == k_num_heads == v_num_heads) \
    and (q_head_dim == k_head_dim == v_head_dim), \
    "query should have layout [batch, q_seq, num_heads, head_dim], " \
    "key and value should have layout [batch, kv_seq, num_heads, head_dim]."

def check_is_flash_attention(query, key):
  batch, q_seq_len, num_heads, head_dim = query.shape
  _, kv_seq_len, _, _ = key.shape
  # check if attention pattern is supported by flash attention or fused attention
  if q_seq_len > 512 and q_seq_len == kv_seq_len and head_dim in [64, 128]:
    # flash attention is supported
    is_flash_attention = True
  elif q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64:
    # regular fused attention is supported
    is_flash_attention = False
  else:
    raise NotImplementedError("Unsupported sequence length and head dim.")
  return is_flash_attention

def check_cuDNN_version(is_flash_attention):
  # check if cuDNN is installed and if the cuDNN version constraint is satisfied
  if cuda_versions is None:
    raise RuntimeError("cuDNN is not detected.")
  elif is_flash_attention and cuda_versions.cudnn_get_version() < 8903:
    raise RuntimeError("Require cuDNN at least 8.9.3 to run flash attention.")
  elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901:
    raise RuntimeError("Require cuDNN at least 8.9.1 to run fused attention.")

def _dot_product_attention_fwd(query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  output, _ = _dot_product_attention_fwd_p_wrapper.bind(
    query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  return output

def _dot_product_attention_fwd_rule(query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  output, activation = _dot_product_attention_fwd_p_wrapper.bind(
    query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  res = (query, key, value, bias, mask, activation, output)
  return output, res

def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, res, grad_output):
  # {Q, K, V, bias, mask, activation, fwd_output, dO}
  query, key, value, bias, mask, activation, fwd_output = res
  grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
    query, key, value, bias, mask, activation, fwd_output, grad_output,
    scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  grads = (grad_query, grad_key, grad_value, None, None)
  return grads

def _dot_product_attention_fwd_impl(query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  # args: {Q, K, V, mask*, bias*}
  output, activation = _dot_product_attention_fwd_p.bind(
    query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  return output, activation

def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, fwd_output, grad_output,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  grad_query, grad_key, grad_value = _dot_product_attention_bwd_p.bind(
    query, key, value, bias, mask, activation, fwd_output, grad_output,
    scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  grads = (grad_query, grad_key, grad_value)
  return grads

def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
    *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_dtype = dtypes.canonicalize_dtype(query.dtype)
  batch, q_seq_len, num_heads, head_dim = query.shape
  _, kv_seq_len, _, _ = key.shape
  output_shape = (batch, q_seq_len, num_heads, head_dim)
  activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
  softmax_stat_shape = (batch, num_heads, q_seq_len)
  if q_seq_len > 512:
    # is flash attention
    return (
      ShapedArray(output_shape, query_dtype),  # output
      ShapedArray(softmax_stat_shape, jnp.float32),  # softmax_stat
    )
  else:
    return (
      ShapedArray(output_shape, query_dtype),  # output
      ShapedArray(activation_shape, query_dtype),  # activation
    )

def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activation, fwd_output, grad_output,
    *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_dtype = dtypes.canonicalize_dtype(query.dtype)
  key_dtype = dtypes.canonicalize_dtype(key.dtype)
  value_dtype = dtypes.canonicalize_dtype(value.dtype)

  return (
    ShapedArray(
      query.shape, query_dtype
    ),  # grad query
    ShapedArray(
      key.shape, key_dtype
    ),  # grad key
    ShapedArray(
      value.shape, value_dtype
    ),  # grad value
  )

def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_type = ir.RankedTensorType(query.type)
  query_shape = query_type.shape
  key_type = ir.RankedTensorType(key.type)
  key_shape = key_type.shape
  value_type = ir.RankedTensorType(value.type)
  value_shape = value_type.shape

  batch, q_seq_len, num_heads, head_dim = query_shape
  _, kv_seq_len, _, _ = key_shape

  output_shape = (batch, num_heads, q_seq_len, head_dim)
  output_layout = (3, 1, 2, 0)
  output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
  activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
  softmax_stat_shape = (batch, num_heads, q_seq_len)
  scratch_shape = (0,)
  scratch_type = ir.IntegerType.get_unsigned(8)
  # get backend config
  backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, False)
  # {Q, K, V, mask*, bias*}
  # {output, scratch, activation*}
  has_dropout = dropout_rate > 0
  has_bias, has_mask = variadic_args
  operands = [query, key, value]
  if has_mask:
    operands.append(mask)
  if has_bias:
    operands.append(bias)
  # get custom call name
  custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, False)
  # create output types and layouts
  if is_flash_attention:
    result_types = [
      ir.RankedTensorType.get(output_shape, query_type.element_type),
      ir.RankedTensorType.get(scratch_shape, scratch_type),
      ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()),
    ]
    result_layouts = [output_layout] + default_layouts(scratch_shape, softmax_stat_shape)
  else:
    result_types = [
      ir.RankedTensorType.get(output_shape, query_type.element_type),
      ir.RankedTensorType.get(scratch_shape, scratch_type),
      ir.RankedTensorType.get(activation_shape, query_type.element_type),
    ]
    result_layouts = [output_layout] + default_layouts(scratch_shape, activation_shape)
  # create custom call here
  out = custom_call(
    custom_call_name,
    result_types=result_types,
    operands=operands,
    backend_config=backend_config,
    operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]),
    result_layouts=result_layouts,
  )
  # drop scratch memory
  # output should be (batch, q_seq_len, num_heads, head_dim) instead of (batch, num_heads, q_seq_len, head_dim)
  return [hlo.transpose(out.results[0], output_transpose_perm), out.results[2]]

def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask, activation, fwd_output, grad_output,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_type = ir.RankedTensorType(query.type)
  query_shape = query_type.shape
  key_type = ir.RankedTensorType(key.type)
  key_shape = key_type.shape
  value_type = ir.RankedTensorType(value.type)
  value_shape = value_type.shape
  activation_type = ir.RankedTensorType(activation.type)
  activation_shape = activation_type.shape
  grad_output_type = ir.RankedTensorType(grad_output.type)
  grad_output_shape = grad_output_type.shape

  batch, q_seq_len, num_heads, head_dim = query_shape
  _, kv_seq_len, _, _ = key_shape
  scratch_shape = (0,)
  scratch_type = ir.IntegerType.get_unsigned(8)

  grad_query_shape = (batch, num_heads, q_seq_len, head_dim)
  grad_key_shape = (batch, num_heads, kv_seq_len, head_dim)
  grad_value_shape = (batch, num_heads, kv_seq_len, head_dim)
  softmax_sum_shape = (batch, num_heads, q_seq_len)
  grad_layout = (3, 1, 2, 0)
  grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
  backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, True)
  # {Q, K, V, activation, dO, mask*, bias*, O*}
  # {dQ, dK, dV, d_S*, softmax_sum*, d_Q_accum*, scratch, dbias*}
  has_dropout = dropout_rate > 0
  has_bias, has_mask = variadic_args
  # create operands
  operands = [query, key, value, activation, grad_output]
  if has_mask:
    operands.append(mask)
  if has_bias and is_flash_attention:
    # flash attention requires bias in the bwd for remat
    operands.append(bias)
  if is_flash_attention:
    operands.append(fwd_output)
  # get custom call name
  custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, True)

  # create output types and layouts
  if is_flash_attention:
    result_types = [
      ir.RankedTensorType.get(grad_query_shape, query_type.element_type),  # grad query
      ir.RankedTensorType.get(grad_key_shape, key_type.element_type),  # grad key
      ir.RankedTensorType.get(grad_value_shape, value_type.element_type),  # grad value
      ir.RankedTensorType.get(softmax_sum_shape, ir.F32Type.get()),  # softmax_sum
      ir.RankedTensorType.get(grad_query_shape, ir.F32Type.get()),  # d_Q_accum
      ir.RankedTensorType.get(scratch_shape, scratch_type),  # scratch
    ]
    result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(softmax_sum_shape, grad_query_shape, scratch_shape)
  else:
    result_types = [
      ir.RankedTensorType.get(grad_query_shape, query_type.element_type),  # grad query
      ir.RankedTensorType.get(grad_key_shape, key_type.element_type),  # grad key
      ir.RankedTensorType.get(grad_value_shape, value_type.element_type),  # grad value
      ir.RankedTensorType.get(activation_shape, activation_type.element_type),  # dS
      ir.RankedTensorType.get(scratch_shape, scratch_type),  # scratch
    ]
    result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(activation_shape, scratch_shape)
  out = custom_call(
    custom_call_name,
    result_types=result_types,
    operands=operands,
    backend_config=backend_config,
    operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]),
    result_layouts=result_layouts,
  )
  # Only keep dQ, dK and dV here
  return [hlo.transpose(out.results[0], grad_transpose_perm),
          hlo.transpose(out.results[1], grad_transpose_perm),
          hlo.transpose(out.results[2], grad_transpose_perm)]

# batcher
def _check_valid_batch_dims(bdims):
  for dim in bdims:
    assert dim in [0, None], \
      "Currently only support batch_dim in [0, None], " \
      f"but got {dim=}"

def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  _check_valid_batch_dims(batch_dims)
  query, key, value, bias, mask = batched_args
  query_bdim = batch_dims[0]
  out_bdims = query_bdim, query_bdim

  *batch_tuple, q_seq_len, num_heads, head_dim = query.shape
  *_, kv_seq_len, _, _ = key.shape
  batch = reduce(operator.mul, batch_tuple)
  has_bias, has_mask = variadic_args
  # reshape to 4D shape
  query = jnp.reshape(query, (batch, q_seq_len, num_heads, head_dim))
  key = jnp.reshape(key, (batch, kv_seq_len, num_heads, head_dim))
  value = jnp.reshape(value, (batch, kv_seq_len, num_heads, head_dim))
  if has_bias:
    bias = jnp.reshape(bias, (batch, num_heads, q_seq_len, kv_seq_len))
  if has_mask:
    mask = jnp.reshape(mask, (batch, num_heads, q_seq_len, kv_seq_len))

  output, activation = _dot_product_attention_fwd_p_wrapper.bind(
    query, key, value, bias, mask,
    scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)

  # reshape to original shape
  output = jnp.reshape(output, (*batch_tuple, q_seq_len, num_heads, head_dim))
  if is_flash_attention:
    activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len))
  else:
    activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len))
  return (output, activation), out_bdims

def _dot_product_attention_bwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  _check_valid_batch_dims(batch_dims)
  query, key, value, bias, mask, activation, fwd_output, grad_output = batched_args
  query_bdim = batch_dims[0]
  out_bdims = query_bdim, query_bdim, query_bdim

  *batch_tuple, q_seq_len, num_heads, head_dim = query.shape
  *_, kv_seq_len, _, _ = key.shape
  batch = reduce(operator.mul, batch_tuple)
  has_bias, has_mask = variadic_args
  # reshape to 4D shape
  query = jnp.reshape(query, (batch, q_seq_len, num_heads, head_dim))
  key = jnp.reshape(key, (batch, kv_seq_len, num_heads, head_dim))
  value = jnp.reshape(value, (batch, kv_seq_len, num_heads, head_dim))
  if has_bias:
    bias = jnp.reshape(bias, (batch, num_heads, q_seq_len, kv_seq_len))
  if has_mask:
    mask = jnp.reshape(mask, (batch, num_heads, q_seq_len, kv_seq_len))
  if is_flash_attention:
    activation = jnp.reshape(activation, (batch, num_heads, q_seq_len))
  else:
    activation = jnp.reshape(activation, (batch, num_heads, q_seq_len, kv_seq_len))
  fwd_output = jnp.reshape(fwd_output, (batch, q_seq_len, num_heads, head_dim))
  grad_output = jnp.reshape(grad_output, (batch, q_seq_len, num_heads, head_dim))

  grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
    query, key, value, bias,
    mask, activation, fwd_output, grad_output,
    scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)

  # reshape to original shape
  grad_query = jnp.reshape(grad_query, (*batch_tuple, q_seq_len, num_heads, head_dim))
  grad_key = jnp.reshape(grad_key, (*batch_tuple, kv_seq_len, num_heads, head_dim))
  grad_value = jnp.reshape(grad_value, (*batch_tuple, kv_seq_len, num_heads, head_dim))
  grads = (grad_query, grad_key, grad_value)
  return grads, out_bdims

# custom partitioning
def _get_padded_spec(arg_info):
  spec = None if arg_info.sharding is None else arg_info.sharding.spec
  ndim = arg_info.ndim
  if spec is None:
    return (None,) * ndim
  assert len(spec) <= ndim
  return spec + (None,) * (ndim - len(spec))

# fwd custom partition
_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10))
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep out sharding same as query sharding since they have same shape
  out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  # activation sharding
  if query_spec[-3] == key_spec[-3]:
    # self attention
    activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
  else:
    # cross attention
    activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3]))
  return (out_sharding, activation_sharding)

def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep out sharding same as query sharding since they have same shape
  out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  # activation sharding
  if query_spec[-3] == key_spec[-3]:
    # self attention
    activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
  else:
    # cross attention
    activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3]))
  # args sharding
  arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
  out_shardings = (out_sharding, activation_sharding)
  impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
                 variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
  return mesh, impl, out_shardings, arg_shardings

# bwd custom partition
_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13))
def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep grad query sharding same as query sharding
  grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
  return out_shardings

def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep grad query sharding same as query sharding
  grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
  # args sharding
  arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
  impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
                 variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
  return mesh, impl, out_shardings, arg_shardings

# Create dot_product_attention_fwd_p for forward operation.
_dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd")
_dot_product_attention_fwd_p.multiple_results = True
_dot_product_attention_fwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_fwd_p))
_dot_product_attention_fwd_p.def_abstract_eval(_dot_product_attention_fwd_abstract)

mlir.register_lowering(
  _dot_product_attention_fwd_p,
  _dot_product_attention_fwd_cuda_lowering,
  platform="gpu",
)

_dot_product_attention_fwd_p_wrapper = core.Primitive("dot_product_attention_fwd_wrapper")
_dot_product_attention_fwd_p_wrapper.multiple_results = True
_dot_product_attention_fwd_p_wrapper.def_impl(_dot_product_attention_fwd_impl)
_dot_product_attention_fwd_p_wrapper.def_abstract_eval(_dot_product_attention_fwd_abstract)

# Create dot_product_attention_bwd_p for backward operation.
_dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd")
_dot_product_attention_bwd_p.multiple_results = True
_dot_product_attention_bwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_bwd_p))
_dot_product_attention_bwd_p.def_abstract_eval(_dot_product_attention_bwd_abstract)

mlir.register_lowering(
  _dot_product_attention_bwd_p,
  _dot_product_attention_bwd_cuda_lowering,
  platform="gpu",
)

_dot_product_attention_bwd_p_wrapper = core.Primitive("dot_product_attention_bwd_wrapper")
_dot_product_attention_bwd_p_wrapper.multiple_results = True
_dot_product_attention_bwd_p_wrapper.def_impl(_dot_product_attention_bwd_impl)
_dot_product_attention_bwd_p_wrapper.def_abstract_eval(_dot_product_attention_bwd_abstract)


batching.primitive_batchers[_dot_product_attention_fwd_p_wrapper] = _dot_product_attention_fwd_batcher
batching.primitive_batchers[_dot_product_attention_bwd_p_wrapper] = _dot_product_attention_bwd_batcher

_dot_product_attention_fwd_lower.def_partition(
  infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands,
  partition=_dot_product_attention_fwd_partition)

mlir.register_lowering(_dot_product_attention_fwd_p_wrapper,
                       mlir.lower_fun(_dot_product_attention_fwd_lower, multiple_results=True))

_dot_product_attention_bwd_lower.def_partition(
  infer_sharding_from_operands=_dot_product_attention_bwd_infer_sharding_from_operands,
  partition=_dot_product_attention_bwd_partition)

mlir.register_lowering(_dot_product_attention_bwd_p_wrapper,
                       mlir.lower_fun(_dot_product_attention_bwd_lower, multiple_results=True))

dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p)
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p_wrapper)
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p)
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p_wrapper)

@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10))
def _dot_product_attention(query: Array,
                           key: Array,
                           value: Array,
                           bias: Array,
                           mask: Array,
                           scale: float,
                           seed: int,
                           dropout_rate: float,
                           variadic_args: tuple[bool],
                           is_flash_attention: bool,
                           is_causal_mask: bool):
  output = _dot_product_attention_fwd(
    query, key, value, bias, mask,
    scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
    is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
  return output

# _dot_product_attention_fwd must have the same func signature as _dot_product_attention
_dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_attention_bwd_rule)

# User interface
def dot_product_attention(query: Array,
                          key: Array,
                          value: Array,
                          scale: float = 1.0,
                          bias: Optional[Array] = None,
                          mask: Optional[Array] = None,
                          is_causal_mask: bool = False,
                          seed: int = 42,
                          dropout_rate: float = 0.):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights.
  Q, K and V are assumed to share the same batch size, number of heads and
  head dim:
  b q_seq num_heads head_dim -> Q
  b kv_seq num_heads head_dim -> K
  b kv_seq num_heads head_dim -> V

  Args:
    query: queries for calculating attention with shape of `[batch, q_length,
      num_heads, qk_depth_per_head]`.
    key: keys for calculating attention with shape of `[batch, kv_length,
      num_heads, qk_depth_per_head]`.
    value: values to be used in attention with shape of `[batch, kv_length,
      num_heads, v_depth_per_head]`.
    scale: scale for the query.
    bias: optional bias added to the attention logits, with shape
      `[batch, num_heads, q_length, kv_length]`.
    mask: optional mask applied to the attention logits, with shape
      `[batch, num_heads, q_length, kv_length]`.
    is_causal_mask: whether to apply a causal mask.
    seed: seed used for the dropout RNG.
    dropout_rate: dropout rate.

  Returns:
    Output of shape `[batch, q_length, num_heads, v_depth_per_head]`.
  """
  # check if query, key and value layout meets cuDNN layout requirement
  check_qkv_layout(query, key, value)
  # check if flash attention is supported for this attention pattern
  is_flash_attention = check_is_flash_attention(query, key)
  # check if cuDNN is installed and if cuDNN version is sufficient
  check_cuDNN_version(is_flash_attention)

  variadic_args = (bias is not None, mask is not None)
  if bias is None:
    bias = jnp.zeros(0, dtype=query.dtype)
  if mask is None:
    mask = jnp.zeros(0, dtype=query.dtype)
  # TODO: remove this once scale behavior is fixed
  if scale != 1.0:
    query = query * scale
    scale = 1.0
  output = _dot_product_attention(
    query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args,
    is_flash_attention, is_causal_mask)
  return output
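
# --- Editorial usage sketch (added in review; not part of the original commit).
# A minimal example of calling the public API above. It assumes an NVIDIA GPU
# with a sufficiently new cuDNN is available at runtime; the shapes, the 0.125
# scale and the helper name `_example_dot_product_attention_usage` are
# illustrative only.
def _example_dot_product_attention_usage():
  # 2 batches, 1024-token sequences, 8 heads of dim 64 -> flash-attention path.
  q = jnp.ones((2, 1024, 8, 64), dtype=jnp.bfloat16)
  k = jnp.ones((2, 1024, 8, 64), dtype=jnp.bfloat16)
  v = jnp.ones((2, 1024, 8, 64), dtype=jnp.bfloat16)
  # Returns an array of shape (2, 1024, 8, 64) with dtype bfloat16.
  return dot_product_attention(q, k, v, scale=0.125, is_causal_mask=True)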

jax/_src/cudnn/fused_attention_stablehlo.py (new file, 732 lines)
@@ -0,0 +1,732 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial, reduce
import operator
from typing import Any, Optional
import json

import jax
import jax.numpy as jnp
from jax import core, dtypes
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call
from jax._src.lib.mlir.dialects import hlo
from jax._src.core import ShapedArray

from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec, NamedSharding

from jax._src.interpreters import batching
from jax._src import dispatch
from jax._src.lib import cuda_versions

Array = jnp.ndarray
DType = jnp.dtype
PRNGKey = jnp.ndarray

def element_type_to_backend_config_type_mapping(dtype):
  _element_type_to_backend_config_type_mapping = {
    ir.BF16Type.get(): "BF16",
    ir.F16Type.get(): "F16",
  }
  return _element_type_to_backend_config_type_mapping[dtype]

def default_layouts(*shapes):
  return [range(len(shape) - 1, -1, -1) for shape in shapes]
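
# --- Editorial sketch (added in review; not part of the original commit).
# default_layouts above produces row-major ("minor-to-major") layouts for the
# custom call; the helper name `_example_default_layouts` is illustrative only.
def _example_default_layouts():
  layouts = default_layouts((2, 4, 8), (16,))
  # -> [[2, 1, 0], [0]]
  return [list(layout) for layout in layouts]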

def create_dot_product_attention_backend_config(batch,
                                                num_heads,
                                                seq_q,
                                                seq_kv,
                                                dtype,
                                                fmha_scale,
                                                seed,
                                                dropout_rate,
                                                is_flash_attention,
                                                is_causal_mask,
                                                is_bwd):
  # b q_seq num_heads head_dim -> Q
  # b kv_seq num_heads head_dim -> K
  # b kv_seq num_heads head_dim -> V
  # b num_heads q_seq kv_seq -> P
  # b q_seq num_heads head_dim -> O
  # bmm1: Q @ K -> P
  # bmm2: P @ V -> O
  # bmm2Grad1: P @ dO -> dV
  # bmm2Grad2: dO @ V -> dP
  # bmm1Grad1: dP @ Q -> dK
  # bmm1Grad2: dP @ K -> dQ
  cudnn_fmha_backend_config = {
    "algorithm": {
      "algo_id": "0",
      "math_type": "TENSOR_OP_MATH",
      "tuning_knobs": {"17": "1", "24": "0"},
      "is_cudnn_frontend": True,
      "workspace_size": "0",
    },
    "fmha_scale": fmha_scale,
    "dropout_rate": dropout_rate,
    "intermediate_tensor_shape": {
      "element_type": element_type_to_backend_config_type_mapping(dtype),
      "dimensions": [str(batch), str(num_heads), str(seq_q), str(seq_kv)],
      "tuple_shapes": [],
      "layout": {
        "dim_level_types": [],
        "dim_unique": [],
        "dim_ordered": [],
        "minor_to_major": ["3", "2", "1", "0"],
        "tiles": [],
        "element_size_in_bits": "0",
        "memory_space": "0",
        "index_primitive_type": "PRIMITIVE_TYPE_INVALID",
        "pointer_primitive_type": "PRIMITIVE_TYPE_INVALID",
        "dynamic_shape_metadata_prefix_bytes": "0",
      },
      "is_dynamic_dimension": [False, False, False, False],
    },
    "seed": seed,
    "is_flash_attention": is_flash_attention,
    "is_causal_mask": is_causal_mask,
  }
  fwd_dot_number = {
    "bmm1_dot_dimension_numbers": {
      "lhs_contracting_dimensions": ["3"],
      "rhs_contracting_dimensions": ["3"],
      "lhs_batch_dimensions": ["0", "2"],
      "rhs_batch_dimensions": ["0", "2"],
    },
    "bmm2_dot_dimension_numbers": {
      "lhs_contracting_dimensions": ["3"],
      "rhs_contracting_dimensions": ["1"],
      "lhs_batch_dimensions": ["0", "1"],
      "rhs_batch_dimensions": ["0", "2"],
    },
  }
  bwd_dot_number = {
    "bmm1_grad_gemm1_dot_dimension_numbers": {
      "lhs_contracting_dimensions": ["2"],
      "rhs_contracting_dimensions": ["1"],
      "lhs_batch_dimensions": ["0", "1"],
      "rhs_batch_dimensions": ["0", "2"],
    },
    "bmm1_grad_gemm2_dot_dimension_numbers": {
      "lhs_contracting_dimensions": ["3"],
      "rhs_contracting_dimensions": ["1"],
      "lhs_batch_dimensions": ["0", "1"],
      "rhs_batch_dimensions": ["0", "2"],
    },
    "bmm2_grad_gemm1_dot_dimension_numbers": {
      "lhs_contracting_dimensions": ["2"],
      "rhs_contracting_dimensions": ["1"],
      "lhs_batch_dimensions": ["0", "1"],
      "rhs_batch_dimensions": ["0", "2"],
    },
    "bmm2_grad_gemm2_dot_dimension_numbers": {
      "lhs_contracting_dimensions": ["3"],
      "rhs_contracting_dimensions": ["3"],
      "lhs_batch_dimensions": ["0", "2"],
      "rhs_batch_dimensions": ["0", "2"],
    },
  }
  if is_bwd:
    cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **bwd_dot_number}
  else:
    cudnn_fmha_backend_config = {**cudnn_fmha_backend_config, **fwd_dot_number}

  backend_config = {
    "operation_queue_id": "0",
    "wait_on_operation_queues": [],
    "cudnn_fmha_backend_config": cudnn_fmha_backend_config
  }
  backend_config = json.dumps(backend_config)
  return backend_config

def get_custom_call_name(has_bias, has_mask, has_dropout, is_bwd):
  index = is_bwd << 3 | has_dropout << 2 | has_mask << 1 | has_bias
  _custom_name_maps = [
    # fMHA forward call targets.
    "__cudnn$fhmaSoftmax",
    "__cudnn$fhmaScaleBiasSoftmax",
    "__cudnn$fhmaScaleMaskSoftmax",
    "__cudnn$fhmaScaleBiasMaskSoftmax",
    "__cudnn$fhmaSoftmaxDropout",
    "__cudnn$fhmaScaleBiasSoftmaxDropout",
    "__cudnn$fhmaScaleMaskSoftmaxDropout",
    "__cudnn$fhmaScaleBiasMaskSoftmaxDropout",
    # fMHA backward call targets.
    "__cudnn$fhmaSoftmaxBackward",
    "__cudnn$fhmaScaleBiasSoftmaxBackward",
    "__cudnn$fhmaScaleMaskSoftmaxBackward",
    "__cudnn$fhmaScaleBiasMaskSoftmaxBackward",
    "__cudnn$fhmaSoftmaxDropoutBackward",
    "__cudnn$fhmaScaleBiasSoftmaxDropoutBackward",
    "__cudnn$fhmaScaleMaskSoftmaxDropoutBackward",
    "__cudnn$fhmaScaleBiasMaskSoftmaxDropoutBackward"
  ]
  return _custom_name_maps[index]
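
# --- Editorial sketch (added in review; not part of the original commit).
# The four flags above are packed into a 4-bit index into _custom_name_maps;
# the helper name `_example_get_custom_call_name` is illustrative only.
def _example_get_custom_call_name():
  # index = 0 << 3 | 1 << 2 | 0 << 1 | 1 = 5
  name = get_custom_call_name(has_bias=True, has_mask=False, has_dropout=True, is_bwd=False)
  return name  # "__cudnn$fhmaScaleBiasSoftmaxDropout"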

def check_qkv_layout(query, key, value):
  assert len(query.shape) == len(key.shape) == len(value.shape) == 4, \
    "query, key and value should have rank 4."

  # Only support fp16 and bf16 here
  query_dtype = query.dtype
  key_dtype = key.dtype
  value_dtype = value.dtype
  assert query_dtype == key_dtype == value_dtype and query_dtype in [jnp.float16, jnp.bfloat16], \
    "query, key and value should have same dtype and should be float16 or bfloat16"

  q_batch, q_seq_len, q_num_heads, q_head_dim = query.shape
  k_batch, k_seq_len, k_num_heads, k_head_dim = key.shape
  v_batch, v_seq_len, v_num_heads, v_head_dim = value.shape
  if not ((q_batch == k_batch == v_batch)
          and (k_seq_len == v_seq_len)
          and (q_num_heads == k_num_heads == v_num_heads)
          and (q_head_dim == k_head_dim == v_head_dim)):
    raise ValueError(
      "query should have layout [batch, q_seq, num_heads, head_dim], "
      "key and value should have layout [batch, kv_seq, num_heads, head_dim].")

def check_is_flash_attention(query, key):
  batch, q_seq_len, num_heads, head_dim = query.shape
  _, kv_seq_len, _, _ = key.shape
  # check if attention pattern is supported by flash attention or fused attention
  if q_seq_len > 512 and q_seq_len == kv_seq_len and head_dim in [64, 128]:
    # flash attention is supported
    is_flash_attention = True
  elif q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64:
    # regular fused attention is supported
    is_flash_attention = False
  else:
    raise NotImplementedError("Unsupported sequence length and head dim.")
  return is_flash_attention

def check_cudnn_version(is_flash_attention):
  # check if cuDNN is installed and if the cuDNN version constraint is satisfied
  if cuda_versions is None:
    raise RuntimeError("cuDNN is not detected.")
  elif is_flash_attention and cuda_versions.cudnn_get_version() < 8903:
    raise RuntimeError("JAX requires cuDNN >= 8.9.3 to use flash attention.")
  elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901:
    raise RuntimeError("JAX requires cuDNN >= 8.9.1 to use fused attention.")

def _dot_product_attention_fwd(query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  output, _ = _dot_product_attention_fwd_p_wrapper.bind(
    query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  return output

def _dot_product_attention_fwd_rule(query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  output, activation = _dot_product_attention_fwd_p_wrapper.bind(
    query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  res = (query, key, value, bias, mask, activation, output)
  return output, res

def _dot_product_attention_bwd_rule(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, res, grad_output):
  query, key, value, bias, mask, activation, fwd_output = res
  grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
    query, key, value, bias, mask, activation, fwd_output, grad_output,
    scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  grads = (grad_query, grad_key, grad_value, None, None)
  return grads

def _dot_product_attention_fwd_impl(query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  # args: {Q, K, V, mask*, bias*}
  output, activation = _dot_product_attention_fwd_p.bind(
    query, key, value, bias, mask, scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  return output, activation

def _dot_product_attention_bwd_impl(query, key, value, bias, mask, activation, fwd_output, grad_output,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  grad_query, grad_key, grad_value = _dot_product_attention_bwd_p.bind(
    query, key, value, bias, mask, activation, fwd_output, grad_output,
    scale=scale, seed=seed, dropout_rate=dropout_rate,
    variadic_args=variadic_args, is_flash_attention=is_flash_attention,
    is_causal_mask=is_causal_mask)
  grads = (grad_query, grad_key, grad_value)
  return grads

def _dot_product_attention_fwd_abstract(query, key, value, bias, mask,
    *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_dtype = dtypes.canonicalize_dtype(query.dtype)
  batch, q_seq_len, num_heads, head_dim = query.shape
  _, kv_seq_len, _, _ = key.shape
  output_shape = (batch, q_seq_len, num_heads, head_dim)
  activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
  softmax_stat_shape = (batch, num_heads, q_seq_len)
  if q_seq_len > 512:
    # is flash attention
    return (
      ShapedArray(output_shape, query_dtype),  # output
      ShapedArray(softmax_stat_shape, jnp.float32),  # softmax_stat
    )
  return (
    ShapedArray(output_shape, query_dtype),  # output
    ShapedArray(activation_shape, query_dtype),  # activation
  )

def _dot_product_attention_bwd_abstract(query, key, value, bias, mask, activation, fwd_output, grad_output,
    *, scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_dtype = dtypes.canonicalize_dtype(query.dtype)
  key_dtype = dtypes.canonicalize_dtype(key.dtype)
  value_dtype = dtypes.canonicalize_dtype(value.dtype)

  return (
    ShapedArray(
      query.shape, query_dtype
    ),  # grad query
    ShapedArray(
      key.shape, key_dtype
    ),  # grad key
    ShapedArray(
      value.shape, value_dtype
    ),  # grad value
  )

def _dot_product_attention_fwd_cuda_lowering(ctx, query, key, value, bias, mask,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_type = ir.RankedTensorType(query.type)
  query_shape = query_type.shape
  key_type = ir.RankedTensorType(key.type)
  key_shape = key_type.shape
  value_type = ir.RankedTensorType(value.type)
  value_shape = value_type.shape

  batch, q_seq_len, num_heads, head_dim = query_shape
  _, kv_seq_len, _, _ = key_shape

  output_shape = (batch, num_heads, q_seq_len, head_dim)
  output_layout = (3, 1, 2, 0)
  output_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
  activation_shape = (batch, num_heads, q_seq_len, kv_seq_len)
  softmax_stat_shape = (batch, num_heads, q_seq_len)
  scratch_shape = (0,)
  scratch_type = ir.IntegerType.get_unsigned(8)
  # get backend config
  backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, False)
  # {Q, K, V, mask*, bias*}
  # {output, scratch, activation*}
  has_dropout = dropout_rate > 0
  has_bias, has_mask = variadic_args
  operands = [query, key, value]
  if has_mask:
    operands.append(mask)
  if has_bias:
    operands.append(bias)
  custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, False)
  # create output types and layouts
  if is_flash_attention:
    result_types = [
      ir.RankedTensorType.get(output_shape, query_type.element_type),
      ir.RankedTensorType.get(scratch_shape, scratch_type),
      ir.RankedTensorType.get(softmax_stat_shape, ir.F32Type.get()),
    ]
    result_layouts = [output_layout] + default_layouts(scratch_shape, softmax_stat_shape)
  else:
    result_types = [
      ir.RankedTensorType.get(output_shape, query_type.element_type),
      ir.RankedTensorType.get(scratch_shape, scratch_type),
      ir.RankedTensorType.get(activation_shape, query_type.element_type),
    ]
    result_layouts = [output_layout] + default_layouts(scratch_shape, activation_shape)
  # create custom call here
  out = custom_call(
    custom_call_name,
    result_types=result_types,
    operands=operands,
    backend_config=backend_config,
    operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]),
    result_layouts=result_layouts,
  )
  # drop scratch memory
  # output should be (batch, q_seq_len, num_heads, head_dim) instead of (batch, num_heads, q_seq_len, head_dim)
  return [hlo.transpose(out.results[0], output_transpose_perm), out.results[2]]

def _dot_product_attention_bwd_cuda_lowering(ctx, query, key, value, bias, mask, activation, fwd_output, grad_output,
    scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask):
  query_type = ir.RankedTensorType(query.type)
  query_shape = query_type.shape
  key_type = ir.RankedTensorType(key.type)
  key_shape = key_type.shape
  value_type = ir.RankedTensorType(value.type)
  value_shape = value_type.shape
  activation_type = ir.RankedTensorType(activation.type)
  activation_shape = activation_type.shape
  grad_output_type = ir.RankedTensorType(grad_output.type)
  grad_output_shape = grad_output_type.shape

  batch, q_seq_len, num_heads, head_dim = query_shape
  _, kv_seq_len, _, _ = key_shape
  scratch_shape = (0,)
  scratch_type = ir.IntegerType.get_unsigned(8)

  grad_query_shape = (batch, num_heads, q_seq_len, head_dim)
  grad_key_shape = (batch, num_heads, kv_seq_len, head_dim)
  grad_value_shape = (batch, num_heads, kv_seq_len, head_dim)
  softmax_sum_shape = (batch, num_heads, q_seq_len)
  grad_layout = (3, 1, 2, 0)
  grad_transpose_perm = mlir.dense_int_array((0, 2, 1, 3))
  backend_config = create_dot_product_attention_backend_config(batch, num_heads, q_seq_len, kv_seq_len, query_type.element_type, scale, seed, dropout_rate, is_flash_attention, is_causal_mask, True)
  # {Q, K, V, activation, dO, mask*, bias*, O*}
  # {dQ, dK, dV, d_S*, softmax_sum*, d_Q_accum*, scratch, dbias*}
  has_dropout = dropout_rate > 0
  has_bias, has_mask = variadic_args
  # create operands
  operands = [query, key, value, activation, grad_output]
  if has_mask:
    operands.append(mask)
  if has_bias and is_flash_attention:
    # flash attention requires bias in the bwd for remat
    operands.append(bias)
  if is_flash_attention:
    operands.append(fwd_output)
  # get custom call name
  custom_call_name = get_custom_call_name(has_bias, has_mask, has_dropout, True)

  # create output types and layouts
  if is_flash_attention:
    result_types = [
      ir.RankedTensorType.get(grad_query_shape, query_type.element_type),  # grad query
      ir.RankedTensorType.get(grad_key_shape, key_type.element_type),  # grad key
      ir.RankedTensorType.get(grad_value_shape, value_type.element_type),  # grad value
      ir.RankedTensorType.get(softmax_sum_shape, ir.F32Type.get()),  # softmax_sum
      ir.RankedTensorType.get(grad_query_shape, ir.F32Type.get()),  # d_Q_accum
      ir.RankedTensorType.get(scratch_shape, scratch_type),  # scratch
    ]
    result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(softmax_sum_shape, grad_query_shape, scratch_shape)
  else:
    result_types = [
      ir.RankedTensorType.get(grad_query_shape, query_type.element_type),  # grad query
      ir.RankedTensorType.get(grad_key_shape, key_type.element_type),  # grad key
      ir.RankedTensorType.get(grad_value_shape, value_type.element_type),  # grad value
      ir.RankedTensorType.get(activation_shape, activation_type.element_type),  # dS
      ir.RankedTensorType.get(scratch_shape, scratch_type),  # scratch
    ]
    result_layouts = [grad_layout, grad_layout, grad_layout] + default_layouts(activation_shape, scratch_shape)
  out = custom_call(
    custom_call_name,
    result_types=result_types,
    operands=operands,
    backend_config=backend_config,
    operand_layouts=default_layouts(*[ir.RankedTensorType(operand.type).shape for operand in operands]),
    result_layouts=result_layouts,
  )
  # Only keep dQ, dK and dV here
  return [hlo.transpose(out.results[0], grad_transpose_perm),
          hlo.transpose(out.results[1], grad_transpose_perm),
          hlo.transpose(out.results[2], grad_transpose_perm)]

# batcher
def _check_valid_batch_dims(bdims):
  for dim in bdims:
    if dim not in [0, None]:
      raise NotImplementedError("Currently only support batch_dim in [0, None], "
                                f"but got {dim=}")

def _dot_product_attention_fwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate,
                                       variadic_args, is_flash_attention, is_causal_mask):
  _check_valid_batch_dims(batch_dims)
  query, key, value, bias, mask = batched_args
  query_bdim = batch_dims[0]
  out_bdims = query_bdim, query_bdim

  *batch_tuple, q_seq_len, num_heads, head_dim = query.shape
  *_, kv_seq_len, _, _ = key.shape
  new_batch = reduce(operator.mul, batch_tuple)
  has_bias, has_mask = variadic_args
  # reshape to 4D shape
  query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, head_dim))
  key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim))
  value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim))
  if has_bias:
    bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len))
  if has_mask:
    mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len))

  output, activation = _dot_product_attention_fwd_p_wrapper.bind(
      query, key, value, bias, mask,
      scale=scale, seed=seed, dropout_rate=dropout_rate,
      variadic_args=variadic_args, is_flash_attention=is_flash_attention,
      is_causal_mask=is_causal_mask)

  # reshape to original shape
  output = jnp.reshape(output, (*batch_tuple, q_seq_len, num_heads, head_dim))
  if is_flash_attention:
    activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len))
  else:
    activation = jnp.reshape(activation, (*batch_tuple, num_heads, q_seq_len, kv_seq_len))
  return (output, activation), out_bdims

def _dot_product_attention_bwd_batcher(batched_args, batch_dims, *, scale, seed, dropout_rate,
                                       variadic_args, is_flash_attention, is_causal_mask):
  _check_valid_batch_dims(batch_dims)
  query, key, value, bias, mask, activation, fwd_output, grad_output = batched_args
  query_bdim = batch_dims[0]
  out_bdims = query_bdim, query_bdim, query_bdim

  *batch_tuple, q_seq_len, num_heads, head_dim = query.shape
  *_, kv_seq_len, _, _ = key.shape
  new_batch = reduce(operator.mul, batch_tuple)
  has_bias, has_mask = variadic_args
  # reshape to 4D shape
  query = jnp.reshape(query, (new_batch, q_seq_len, num_heads, head_dim))
  key = jnp.reshape(key, (new_batch, kv_seq_len, num_heads, head_dim))
  value = jnp.reshape(value, (new_batch, kv_seq_len, num_heads, head_dim))
  if has_bias:
    bias = jnp.reshape(bias, (new_batch, num_heads, q_seq_len, kv_seq_len))
  if has_mask:
    mask = jnp.reshape(mask, (new_batch, num_heads, q_seq_len, kv_seq_len))
  if is_flash_attention:
    activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len))
  else:
    activation = jnp.reshape(activation, (new_batch, num_heads, q_seq_len, kv_seq_len))
  fwd_output = jnp.reshape(fwd_output, (new_batch, q_seq_len, num_heads, head_dim))
  grad_output = jnp.reshape(grad_output, (new_batch, q_seq_len, num_heads, head_dim))

  grad_query, grad_key, grad_value = _dot_product_attention_bwd_p_wrapper.bind(
      query, key, value, bias,
      mask, activation, fwd_output, grad_output,
      scale=scale, seed=seed, dropout_rate=dropout_rate,
      variadic_args=variadic_args, is_flash_attention=is_flash_attention,
      is_causal_mask=is_causal_mask)

  # reshape to original shape
  grad_query = jnp.reshape(grad_query, (*batch_tuple, q_seq_len, num_heads, head_dim))
  grad_key = jnp.reshape(grad_key, (*batch_tuple, kv_seq_len, num_heads, head_dim))
  grad_value = jnp.reshape(grad_value, (*batch_tuple, kv_seq_len, num_heads, head_dim))
  grads = (grad_query, grad_key, grad_value)
  return grads, out_bdims
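
# Illustrative sketch (not part of the original module): the batchers above
# fold any extra leading batch dimensions into a single batch dimension,
# bind the 4-D primitive, and then restore the original shape, which is what
# lets `jax.vmap` be applied over this attention op. A minimal demonstration
# of that reshape pattern (shapes are made up for illustration):
#
#   q = jnp.zeros((2, 3, 128, 8, 64))                 # (*batch_tuple, q_seq, heads, head_dim)
#   *batch_tuple, q_seq, heads, head_dim = q.shape
#   new_batch = reduce(operator.mul, batch_tuple)     # 2 * 3 == 6
#   q4 = jnp.reshape(q, (new_batch, q_seq, heads, head_dim))        # what gets bound
#   q5 = jnp.reshape(q4, (*batch_tuple, q_seq, heads, head_dim))    # restored afterwards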

# custom partitioning
def _get_padded_spec(arg_info):
  spec = None if arg_info.sharding is None else arg_info.sharding.spec
  ndim = arg_info.ndim
  if spec is None:
    return (None,) * ndim
  assert len(spec) <= ndim
  return spec + (None,) * (ndim - len(spec))
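
# Illustrative sketch (not part of the original module): _get_padded_spec pads
# a PartitionSpec with None entries so that it has one entry per array
# dimension. For a hypothetical 4-D operand sharded only on its leading axis:
#
#   PartitionSpec('data')  ->  ('data', None, None, None)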

# fwd custom partition
_dot_product_attention_fwd_lower = custom_partitioning(
    _dot_product_attention_fwd_impl, static_argnums=(5, 6, 7, 8, 9, 10))

def _dot_product_attention_fwd_infer_sharding_from_operands(
    scale, seed, dropout_rate, variadic_args, is_flash_attention,
    is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep out sharding same as query sharding since they have same shape
  out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  # activation sharding
  if query_spec[-3] == key_spec[-3]:
    # self attention
    activation_sharding = NamedSharding(
        mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
  else:
    # cross attention
    activation_sharding = NamedSharding(
        mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3]))
  return (out_sharding, activation_sharding)
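
# Illustrative sketch (not part of the original module): the activation spec is
# a permutation of the query spec, since the activation tensor is laid out as
# (*batch, num_heads, q_seq, ...). With a hypothetical query spec of
# ('data', None, 'model', None), i.e. (*batch, q_seq, num_head, head):
#
#   out spec        -> ('data', None, 'model', None)
#   activation spec -> ('data', 'model', None, None)   # self-attention branch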

def _dot_product_attention_fwd_partition(
    scale, seed, dropout_rate, variadic_args, is_flash_attention,
    is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep out sharding same as query sharding since they have same shape
  out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  # activation sharding
  if query_spec[-3] == key_spec[-3]:
    # self attention
    activation_sharding = NamedSharding(
        mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
  else:
    # cross attention
    activation_sharding = NamedSharding(
        mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3]))
  # args sharding
  arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
  out_shardings = (out_sharding, activation_sharding)
  impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed,
                 dropout_rate=dropout_rate, variadic_args=variadic_args,
                 is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
  return mesh, impl, out_shardings, arg_shardings

# bwd custom partition
_dot_product_attention_bwd_lower = custom_partitioning(
    _dot_product_attention_bwd_impl, static_argnums=(8, 9, 10, 11, 12, 13))

def _dot_product_attention_bwd_infer_sharding_from_operands(
    scale, seed, dropout_rate, variadic_args, is_flash_attention,
    is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep grad query sharding same as query sharding
  grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
  return out_shardings

def _dot_product_attention_bwd_partition(
    scale, seed, dropout_rate, variadic_args, is_flash_attention,
    is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep grad query sharding same as query sharding
  grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
  # args sharding
  arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
  impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed,
                 dropout_rate=dropout_rate, variadic_args=variadic_args,
                 is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
  return mesh, impl, out_shardings, arg_shardings

# Create dot_product_attention_fwd_p for forward operation.
_dot_product_attention_fwd_p = core.Primitive("dot_product_attention_fwd")
_dot_product_attention_fwd_p.multiple_results = True
_dot_product_attention_fwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_fwd_p))
_dot_product_attention_fwd_p.def_abstract_eval(_dot_product_attention_fwd_abstract)

mlir.register_lowering(
  _dot_product_attention_fwd_p,
  _dot_product_attention_fwd_cuda_lowering,
  platform="cuda",
)

_dot_product_attention_fwd_p_wrapper = core.Primitive("dot_product_attention_fwd_wrapper")
_dot_product_attention_fwd_p_wrapper.multiple_results = True
_dot_product_attention_fwd_p_wrapper.def_impl(_dot_product_attention_fwd_impl)
_dot_product_attention_fwd_p_wrapper.def_abstract_eval(_dot_product_attention_fwd_abstract)

# Create dot_product_attention_bwd_p for backward operation.
_dot_product_attention_bwd_p = core.Primitive("dot_product_attention_bwd")
_dot_product_attention_bwd_p.multiple_results = True
_dot_product_attention_bwd_p.def_impl(partial(xla.apply_primitive, _dot_product_attention_bwd_p))
_dot_product_attention_bwd_p.def_abstract_eval(_dot_product_attention_bwd_abstract)

mlir.register_lowering(
  _dot_product_attention_bwd_p,
  _dot_product_attention_bwd_cuda_lowering,
  platform="cuda",
)

_dot_product_attention_bwd_p_wrapper = core.Primitive("dot_product_attention_bwd_wrapper")
_dot_product_attention_bwd_p_wrapper.multiple_results = True
_dot_product_attention_bwd_p_wrapper.def_impl(_dot_product_attention_bwd_impl)
_dot_product_attention_bwd_p_wrapper.def_abstract_eval(_dot_product_attention_bwd_abstract)


batching.primitive_batchers[_dot_product_attention_fwd_p_wrapper] = _dot_product_attention_fwd_batcher
batching.primitive_batchers[_dot_product_attention_bwd_p_wrapper] = _dot_product_attention_bwd_batcher

_dot_product_attention_fwd_lower.def_partition(
  infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands,
  partition=_dot_product_attention_fwd_partition)

mlir.register_lowering(_dot_product_attention_fwd_p_wrapper,
                       mlir.lower_fun(_dot_product_attention_fwd_lower, multiple_results=True))

_dot_product_attention_bwd_lower.def_partition(
  infer_sharding_from_operands=_dot_product_attention_bwd_infer_sharding_from_operands,
  partition=_dot_product_attention_bwd_partition)

mlir.register_lowering(_dot_product_attention_bwd_p_wrapper,
                       mlir.lower_fun(_dot_product_attention_bwd_lower, multiple_results=True))

dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p)
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_fwd_p_wrapper)
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p)
dispatch.prim_requires_devices_during_lowering.add(_dot_product_attention_bwd_p_wrapper)
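
# Illustrative sketch (not part of the original module): with the primitives,
# batchers and custom partitioning registered above, the public API defined
# below can be sharded across a device mesh via pjit. Mesh axis names, the
# device layout and shapes here are hypothetical:
#
#   import numpy as np
#   mesh = Mesh(np.array(jax.devices()).reshape(2, 2), ('dp', 'tp'))
#   qkv_sharding = NamedSharding(mesh, PartitionSpec('dp', None, 'tp', None))
#   sharded_attn = pjit(dot_product_attention,
#                       in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding),
#                       out_shardings=qkv_sharding)
#   out = sharded_attn(query, key, value)   # query/key/value: (batch, seq, heads, head_dim)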

@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9, 10))
def _dot_product_attention(query: Array,
                           key: Array,
                           value: Array,
                           bias: Array,
                           mask: Array,
                           scale: float,
                           seed: int,
                           dropout_rate: float,
                           variadic_args: tuple[bool, ...],
                           is_flash_attention: bool,
                           is_causal_mask: bool):
  output = _dot_product_attention_fwd(
      query, key, value, bias, mask,
      scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args,
      is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
  return output

# _dot_product_attention_fwd must have the same func signature as _dot_product_attention
_dot_product_attention.defvjp(_dot_product_attention_fwd_rule, _dot_product_attention_bwd_rule)

# User interface
def dot_product_attention(query: Array,
                          key: Array,
                          value: Array,
                          scale: float = 1.0,
                          bias: Optional[Array] = None,
                          mask: Optional[Array] = None,
                          is_causal_mask: bool = False,
                          seed: int = 42,
                          dropout_rate: float = 0.):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights.

  Q, K and V are assumed to share the `[batch, seq, num_heads, head_dim]` layout:

    b q_seq  num_heads head_dim -> Q
    b kv_seq num_heads head_dim -> K
    b kv_seq num_heads head_dim -> V

  Args:
    query: queries for calculating attention with shape of `[batch, q_length,
      num_heads, qk_depth_per_head]`.
    key: keys for calculating attention with shape of `[batch, kv_length,
      num_heads, qk_depth_per_head]`.
    value: values to be used in attention with shape of `[batch, kv_length,
      num_heads, v_depth_per_head]`.
    scale: scale for the query.
    bias: bias to be added to logits with shape of `[batch, num_heads,
      q_length, kv_length]`.
    mask: mask used to mask out logits with shape of `[batch, num_heads,
      q_length, kv_length]`.
    is_causal_mask: whether to apply a causal mask to the attention logits.
    seed: seed used for the dropout RNG.
    dropout_rate: dropout rate.

  Returns:
    Output of shape `[batch, q_length, num_heads, v_depth_per_head]`.
  """
  # check if query, key and value layout meets cuDNN layout requirement
  check_qkv_layout(query, key, value)
  # check if flash attention is supported for this attention pattern
  is_flash_attention = check_is_flash_attention(query, key)
  # check if cuDNN is installed and if cuDNN version is sufficient
  check_cudnn_version(is_flash_attention)

  variadic_args = (bias is not None, mask is not None)
  if bias is None:
    bias = jnp.zeros(0, dtype=query.dtype)
  if mask is None:
    mask = jnp.zeros(0, dtype=query.dtype)
  # TODO: remove this once scale behavior is fixed
  if scale != 1.0:
    query = query * scale
    scale = 1.0
  output = _dot_product_attention(
      query, key, value, bias, mask,
      scale, seed, dropout_rate, variadic_args,
      is_flash_attention, is_causal_mask)
  return output
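
# Illustrative usage sketch (not part of the original module); shapes, dtype
# and the scale value are made up for illustration:
#
#   q = jax.random.normal(jax.random.PRNGKey(0), (4, 1024, 8, 64), dtype=jnp.bfloat16)
#   k = jax.random.normal(jax.random.PRNGKey(1), (4, 1024, 8, 64), dtype=jnp.bfloat16)
#   v = jax.random.normal(jax.random.PRNGKey(2), (4, 1024, 8, 64), dtype=jnp.bfloat16)
#   out = dot_product_attention(q, k, v, scale=0.125)   # 0.125 == 1/sqrt(head_dim)
#   # out.shape == (4, 1024, 8, 64)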
tests/BUILD
@ -1430,6 +1430,19 @@ jax_test(
    ],
)

jax_test(
    name = "fused_attention_stablehlo_test",
    srcs = ["fused_attention_stablehlo_test.py"],
    disable_backends = [
        "tpu",
        "cpu",
    ],
    shard_count = 4,
    deps = [
        "//jax:fused_attention_stablehlo",
    ],
)

exports_files(
    [
        "api_test.py",
tests/fused_attention_stablehlo_test.py
@ -16,6 +16,8 @@ from functools import partial
from absl.testing import absltest
from typing import Any, Optional
import os
os.environ['XLA_FLAGS'] = '--xla_dump_disable_metadata --xla_gpu_enable_triton_gemm=false --xla_dump_hlo_as_text --xla_dump_to=./scratch/hlo --xla_dump_hlo_module_re=.*pjit__unnamed_function.* --xla_dump_hlo_pass_re=.* --xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'

import numpy as np
import jax
import jax.numpy as jnp
@ -24,41 +26,41 @@ from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.pjit import pjit
from jax._src import config
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stableHLO import dot_product_attention
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

config.parse_flags_with_absl()
Array = jnp.ndarray

def f(query: Array,
      key: Array,
      value: Array,
      bias: Optional[Array] = None,
      mask: Optional[Array] = None,
      causal_mask: bool = False,
      scale: float = 0.5,
      dropout_rate: float = 0.1) -> Array:

  output = dot_product_attention(
      query,
      key,
      value,
      scale=scale,
      bias=bias,
      mask=mask,
      is_causal_mask=causal_mask,
      dropout_rate=dropout_rate)
  return output

def f_train(query: Array,
            key: Array,
            value: Array,
            grad: Array,
            bias: Optional[Array] = None,
            mask: Optional[Array] = None,
            causal_mask: bool = False,
            scale: float = 0.5,
            dropout_rate: float = 0.1) -> Array:

  output = dot_product_attention(
      query,
      key,
      value,
      scale=scale,
      bias=bias,
      mask=mask,
      is_causal_mask=causal_mask,
      dropout_rate=dropout_rate)
  return output

def f_train(query: Array,
            key: Array,
            value: Array,
            grad: Array,
            bias: Optional[Array] = None,
            mask: Optional[Array] = None,
            causal_mask: bool = False,
            scale: float = 0.5,
            dropout_rate: float = 0.1) -> Array:

  out, f_vjp = jax.vjp(
      partial(f, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
      query, key, value, bias, None)
@ -101,33 +103,27 @@ def g(query: Array,
  attn_weights = jax.nn.softmax(attn_weights)
  if dropout_rate > 0.:
    keep_prob = 1.0 - dropout_rate
    dropout_shape = list(attn_weights.shape)
    dropout_shape[-2] = 1
    dropout_rng = jax.random.PRNGKey(0)
    keep = jax.random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    keep = jnp.broadcast_to(keep, attn_weights.shape)
    multiplier = (
        keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=attn_weights.dtype))
    attn_weights = attn_weights * multiplier
    dropout_rng = jax.random.key(0)
    keep = jax.random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
    attn_weights = jax.lax.select(keep, attn_weights / keep_prob, jnp.zeros_like(attn_weights))

  return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)

def g_train(query: Array,
            key: Array,
            value: Array,
            grad: Array,
            bias: Optional[Array] = None,
            mask: Optional[Array] = None,
            causal_mask: bool = False,
            scale: float = 0.5,
            dropout_rate: float = 0.1) -> Array:
            key: Array,
            value: Array,
            grad: Array,
            bias: Optional[Array] = None,
            mask: Optional[Array] = None,
            causal_mask: bool = False,
            scale: float = 0.5,
            dropout_rate: float = 0.1) -> Array:
  out_ref, g_vjp = jax.vjp(
      partial(g, scale=scale, causal_mask=causal_mask, dropout_rate=dropout_rate),
      query, key, value, bias, None)
  query_grad_ref, key_grad_ref, value_grad_ref, _, _ = g_vjp(grad)
  return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref)

@jtu.with_config(jax_legacy_prng_key='allow')
class DotProductAttentionTest(jtu.JaxTestCase):
  @jtu.sample_product(
      batch_size=[4],
@ -136,7 +132,7 @@ class DotProductAttentionTest(jtu.JaxTestCase):
      head_dim=[64, 128],
      use_bias=[True],
      is_causal_mask=[False],
      dropout_rate=[0],
      dropout_rate=[0, 0.5],
      scale=[0.5],
      dtype=[jnp.float16, jnp.bfloat16]
  )
@ -144,15 +140,14 @@ class DotProductAttentionTest(jtu.JaxTestCase):
  def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int,
                head_dim: int, use_bias: bool, is_causal_mask: bool,
                dropout_rate: float, scale: float, dtype: jnp.dtype):
    if (seq_len == 256 and is_causal_mask):
    if seq_len == 256 and is_causal_mask:
      self.skipTest("Fused attention does not support mask generation.")
    if (seq_len == 256 and head_dim == 128):
      self.skipTest("Fused attention does not head dim = 128.")
    if seq_len == 256 and head_dim == 128:
      self.skipTest("Fused attention does not support head dim = 128.")
    if len(jax.local_devices()) <= 4:
      self.skipTest("Require at least 4 devices to run sharding tests.")
    os.environ['XLA_FLAGS'] = '--xla_gpu_enable_cudnn_fmha=true --xla_gpu_fused_attention_use_cudnn_rng=true'

    k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
    k1, k2, k3, k4, k5 = jax.random.split(jax.random.key(0), 5)
    query = jax.random.normal(
        k1, (batch_size, seq_len, num_heads, head_dim), dtype=dtype)
    key = jax.random.normal(
@ -197,14 +192,14 @@ class DotProductAttentionTest(jtu.JaxTestCase):

    out, (query_grad, key_grad, value_grad) = pjitted_f_train(query, key, value, grad, bias, None)
    out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = pjitted_g_train(query, key, value, grad, bias, None)
    assert jnp.allclose(out_ref, out, rtol=1e-5, atol=1e-5)
    self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5)
    if seq_len > 512:
      # query_grad in flash attention is not deterministic
      assert jnp.allclose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2)
      self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2)
    else:
      assert jnp.allclose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5)
      assert jnp.allclose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
      assert jnp.allclose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)
      self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-5, atol=1e-5)
      self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5)
      self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5)

if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())