replace pjit with jit and only allow shardings on batch/head dim

Cjkkkk 2024-01-18 15:26:27 -08:00
parent 5708fb955b
commit 40eb11bc79
2 changed files with 48 additions and 52 deletions

@@ -204,8 +204,9 @@ def check_qkv_layout(query, key, value):
def check_is_flash_attention(query, key):
batch, q_seq_len, num_heads, head_dim = query.shape
_, kv_sqe_len, _, _ = key.shape
is_cross_attention = q_seq_len != kv_sqe_len
# check if attention pattern is supported by flash attention or fused attention
if q_seq_len > 512 and q_seq_len == kv_sqe_len and head_dim in [64, 128]:
if q_seq_len > 512 and kv_sqe_len > 512 and head_dim in [64, 128]:
# check if flash attention is supported
is_flash_attention = True
elif q_seq_len <= 512 and kv_sqe_len <= 512 and head_dim == 64:
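The key change in this hunk is that the flash path no longer requires q_seq_len == kv_seq_len, only that both exceed 512, which admits cross attention. A small sketch (not part of the commit) of how the new condition classifies a few hypothetical shapes:

```python
# Illustrative only: classify a few hypothetical (q_seq_len, kv_seq_len, head_dim)
# combinations under the new condition; `classify` is a made-up helper name.
def classify(q_seq_len, kv_seq_len, head_dim):
    if q_seq_len > 512 and kv_seq_len > 512 and head_dim in [64, 128]:
        return "flash attention"        # now also covers cross attention (q_seq_len != kv_seq_len)
    elif q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64:
        return "fused (non-flash) attention"
    else:
        return "NotImplementedError"

print(classify(1024, 1024, 128))  # flash attention (self attention)
print(classify(1024, 2048, 64))   # flash attention (cross attention, allowed after this change)
print(classify(256, 256, 64))     # fused (non-flash) attention
```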
@@ -213,14 +214,17 @@ def check_is_flash_attention(query, key):
is_flash_attention = False
else:
raise NotImplementedError("Unsupported sequence length and head dim.")
return is_flash_attention
return is_flash_attention, is_cross_attention
def check_cudnn_version(is_flash_attention):
def check_cudnn_version(is_flash_attention, is_cross_attention):
# check if cuDNN is installed and if the cuDNN version constraint is satisfied
if cuda_versions is None:
raise RuntimeError("cuDNN is not detected.")
elif is_flash_attention and cuda_versions.cudnn_get_version() < 8903:
raise RuntimeError("JAX requires cuDNN >= 8.9.3 to use flash attention.")
elif is_flash_attention:
if not is_cross_attention and cuda_versions.cudnn_get_version() < 8903:
raise RuntimeError("JAX requires cuDNN >= 8.9.3 to use flash attention.")
if is_cross_attention and cuda_versions.cudnn_get_version() < 8904:
raise RuntimeError("JAX requires cuDNN >= 8.9.4 to use flash cross attention.")
elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901:
raise RuntimeError("JAX requires cuDNN >= 8.9.1 to use fused attention.")
@@ -527,49 +531,43 @@ def _get_padded_spec(arg_info):
return (None,) * ndim
assert len(spec) <= ndim
return spec + (None,) * (ndim - len(spec))
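Only the tail of _get_padded_spec is visible in this hunk; it pads a partition spec with None entries up to the operand's rank (spec and ndim are extracted from arg_info above the hunk, so the concrete values below are illustrative only):

```python
# Illustrative values only: pad a 3-entry spec out to a rank-4 operand.
spec, ndim = ('dp', None, 'tp'), 4
padded = spec + (None,) * (ndim - len(spec))  # -> ('dp', None, 'tp', None)
```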
# fwd custom partition
_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10))
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
def _infer_fwd_output_sharding(mesh, arg_shapes):
# only sharding on batch and num_head dim is allowed
# (*batch, q_seq, num_head, head)
query_spec = _get_padded_spec(arg_shapes[0])
# (*batch, kv_seq, num_head, head)
key_spec = _get_padded_spec(arg_shapes[1])
value_spec = _get_padded_spec(arg_shapes[2])
if not query_spec == key_spec == value_spec:
raise ValueError("Query, key and value should have same sharding.")
seq_spec = query_spec[-3]
head_spec = query_spec[-1]
if seq_spec != None:
raise ValueError("Sharding on sequence dim is not allowed.")
if head_spec != None:
raise ValueError("Sharding on head dim is not allowed.")
# keep out sharding same as query sharding since they have same shape
out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
# activation sharding
if query_spec[-3] == key_spec[-3]:
# self attention
activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
else:
# cross attention
activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3]))
activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
return (out_sharding, activation_sharding)
_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10))
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
return _infer_fwd_output_sharding(mesh, arg_shapes)
def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
# (*batch, q_seq, num_head, head)
query_spec = _get_padded_spec(arg_shapes[0])
# (*batch, kv_seq, num_head, head)
key_spec = _get_padded_spec(arg_shapes[1])
# keep out sharding same as query sharding since they have same shape
out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
# activation sharding
if query_spec[-3] == key_spec[-3]:
# self attention
activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
else:
# cross attention
activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3]))
# args sharding
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
out_shardings = (out_sharding, activation_sharding)
out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes)
impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
return mesh, impl, out_shardings, arg_shardings
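Both the sharding-inference and partition callbacks above funnel through _infer_fwd_output_sharding, which rejects sharding on the sequence and head dims, so only batch and num_heads may be partitioned. The callbacks are presumably registered on _dot_product_attention_fwd_lower via custom_partitioning's def_partition; that registration is outside this hunk, so the snippet below only shows the standard jax.experimental.custom_partitioning wiring:

```python
# Standard custom_partitioning wiring (an assumption -- the actual registration
# call is outside this hunk); runs in the same module as the definitions above.
_dot_product_attention_fwd_lower.def_partition(
    infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands,
    partition=_dot_product_attention_fwd_partition)
```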
# bwd custom partition
_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13))
def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
def _infer_bwd_output_sharding(mesh, arg_shapes):
# (*batch, q_seq, num_head, head)
query_spec = _get_padded_spec(arg_shapes[0])
# (*batch, kv_seq, num_head, head)
@@ -581,16 +579,12 @@ def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout
out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
return out_shardings
_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13))
def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
return _infer_bwd_output_sharding(mesh, arg_shapes)
def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
# (*batch, q_seq, num_head, head)
query_spec = _get_padded_spec(arg_shapes[0])
# (*batch, kv_seq, num_head, head)
key_spec = _get_padded_spec(arg_shapes[1])
# keep grad query sharding same as query sharding
grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes)
# args sharding
arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
@@ -712,9 +706,9 @@ def dot_product_attention(query: Array,
# check if query, key and value layout meets cuDNN layout requirement
check_qkv_layout(query, key, value)
# check if flash attention is supported for this attention pattern
is_flash_attention = check_is_flash_attention(query, key)
is_flash_attention, is_cross_attention = check_is_flash_attention(query, key)
# check if cuDNN is installed and if cuDNN version is sufficient
check_cudnn_version(is_flash_attention)
check_cudnn_version(is_flash_attention, is_cross_attention)
variadic_args = (bias is not None, mask is not None)
if bias is None:

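A minimal usage sketch of the public entry point with flash-compatible shapes; only the arguments visible in this diff (query, key, value, plus optional bias and mask) are assumed, and a cuDNN-capable GPU build is required:

```python
# Minimal usage sketch; bias/mask are assumed to default to None, and any other
# keyword arguments of dot_product_attention are not shown in this diff.
import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

q = jnp.zeros((2, 1024, 8, 64), dtype=jnp.bfloat16)  # (batch, q_seq, num_heads, head_dim)
k = jnp.zeros((2, 1024, 8, 64), dtype=jnp.bfloat16)
v = jnp.zeros((2, 1024, 8, 64), dtype=jnp.bfloat16)

out = jax.jit(dot_product_attention)(q, k, v)
```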
@@ -23,7 +23,6 @@ import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.pjit import pjit
from jax._src import config
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
@@ -168,24 +167,27 @@ class DotProductAttentionTest(jtu.JaxTestCase):
devices = devices.reshape((2, 2))
with Mesh(devices, ('dp', 'tp')) as mesh:
qkv_spec = PartitionSpec('dp', None, 'tp', None)
qkv_sharding = NamedSharding(mesh, qkv_spec)
if bias is not None:
bias_spec = PartitionSpec('dp', 'tp', None, None)
else:
bias_spec = None
query = jax.device_put(query, NamedSharding(mesh, qkv_spec))
key = jax.device_put(key, NamedSharding(mesh, qkv_spec))
value = jax.device_put(value, NamedSharding(mesh, qkv_spec))
bias_spec = PartitionSpec()
bias_sharding = NamedSharding(mesh, bias_spec)
replicated = NamedSharding(mesh, PartitionSpec())
query = jax.device_put(query, qkv_sharding)
key = jax.device_put(key, qkv_sharding)
value = jax.device_put(value, qkv_sharding)
if bias is not None:
bias = jax.device_put(bias, NamedSharding(mesh, bias_spec))
grad = jax.device_put(grad, NamedSharding(mesh, qkv_spec))
in_shardings = (qkv_spec, qkv_spec, qkv_spec, qkv_spec, bias_spec, None)
out_shardings = (None, (qkv_spec, qkv_spec, qkv_spec))
pjitted_f_train = pjit(jitted_f_train,
bias = jax.device_put(bias, bias_sharding)
grad = jax.device_put(grad, qkv_sharding)
in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding, replicated)
out_shardings = (replicated, (qkv_sharding, qkv_sharding, qkv_sharding))
pjitted_f_train = jax.jit(jitted_f_train,
in_shardings=in_shardings,
out_shardings=out_shardings
)
pjitted_g_train = pjit(jitted_g_train,
pjitted_g_train = jax.jit(jitted_g_train,
in_shardings=in_shardings,
out_shardings=out_shardings
)
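The substance of the test change is that jax.jit now accepts in_shardings/out_shardings directly, so pjit is no longer needed, and inputs are sharded only on the batch ('dp') and num_heads ('tp') dimensions to satisfy the partitioning rules above. A standalone sketch of the same pattern, assuming at least four devices and using an element-wise stand-in for the attention function:

```python
# Standalone sketch of the pjit -> jax.jit pattern used in the test; assumes at
# least 4 devices, and the lambda is a stand-in for the attention call.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:4]).reshape((2, 2))
with Mesh(devices, ('dp', 'tp')) as mesh:
    # shard only batch ('dp') and num_heads ('tp'); seq and head_dim stay unsharded
    qkv_sharding = NamedSharding(mesh, PartitionSpec('dp', None, 'tp', None))

    x = jax.device_put(jnp.ones((4, 1024, 8, 64)), qkv_sharding)

    f = jax.jit(lambda a: a * 2.0,
                in_shardings=(qkv_sharding,),
                out_shardings=qkv_sharding)
    y = f(x)
```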