replace pjit with jit and only allow shardings on batch/head dim
This commit is contained in:
parent 5708fb955b
commit 40eb11bc79
@@ -204,8 +204,9 @@ def check_qkv_layout(query, key, value):
def check_is_flash_attention(query, key):
  batch, q_seq_len, num_heads, head_dim = query.shape
  _, kv_sqe_len, _, _ = key.shape
  is_cross_attention = q_seq_len != kv_sqe_len
  # check if attention pattern is supported by flash attention or fused attention
  if q_seq_len > 512 and q_seq_len == kv_sqe_len and head_dim in [64, 128]:
  if q_seq_len > 512 and kv_sqe_len > 512 and head_dim in [64, 128]:
    # check if flash attention is supported
    is_flash_attention = True
  elif q_seq_len <= 512 and kv_sqe_len <= 512 and head_dim == 64:
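Reviewer note: the condition above is relaxed from "equal sequence lengths above 512" to "both sequence lengths above 512", which is what makes cross attention eligible for the flash path. A minimal sketch of the resulting dispatch; the helper name select_attention_backend is hypothetical and not part of the diff:

def select_attention_backend(q_seq_len: int, kv_seq_len: int, head_dim: int) -> str:
  # Mirrors the constraints checked in check_is_flash_attention after this commit.
  if q_seq_len > 512 and kv_seq_len > 512 and head_dim in (64, 128):
    return "flash"  # flash attention; q_seq_len != kv_seq_len (cross attention) is now allowed
  if q_seq_len <= 512 and kv_seq_len <= 512 and head_dim == 64:
    return "fused"  # the non-flash fused attention path
  raise NotImplementedError("Unsupported sequence length and head dim.")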
@@ -213,14 +214,17 @@ def check_is_flash_attention(query, key):
    is_flash_attention = False
  else:
    raise NotImplementedError("Unsupported sequence length and head dim.")
  return is_flash_attention
  return is_flash_attention, is_cross_attention

def check_cudnn_version(is_flash_attention):
def check_cudnn_version(is_flash_attention, is_cross_attention):
  # check if cuDNN is installed and if cuDNN version contraint is satisfied
  if cuda_versions is None:
    raise RuntimeError("cuDNN is not detected.")
  elif is_flash_attention and cuda_versions.cudnn_get_version() < 8903:
    raise RuntimeError("JAX requires cuDNN >= 8.9.3 to use flash attention.")
  elif is_flash_attention:
    if not is_cross_attention and cuda_versions.cudnn_get_version() < 8903:
      raise RuntimeError("JAX requires cuDNN >= 8.9.3 to use flash attention.")
    if is_cross_attention and cuda_versions.cudnn_get_version() < 8904:
      raise RuntimeError("JAX requires cuDNN >= 8.9.4 to use flash cross attention.")
  elif not is_flash_attention and cuda_versions.cudnn_get_version() < 8901:
    raise RuntimeError("JAX requires cuDNN >= 8.9.1 to use fused attention.")
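Reviewer note: the thresholds compare against cudnn_get_version(), which for the cuDNN 8.x releases checked here packs the version as major*1000 + minor*100 + patch, so 8903 is 8.9.3, 8904 is 8.9.4 and 8901 is 8.9.1. A small illustrative decoder (not part of the diff):

def cudnn_version_tuple(packed: int) -> tuple:
  # cuDNN 8.x packs its version as major*1000 + minor*100 + patch,
  # e.g. 8903 -> (8, 9, 3) and 8904 -> (8, 9, 4).
  return (packed // 1000, (packed % 1000) // 100, packed % 100)

assert cudnn_version_tuple(8904) == (8, 9, 4)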
@@ -527,49 +531,43 @@ def _get_padded_spec(arg_info):
    return (None,) * ndim
  assert len(spec) <= ndim
  return spec + (None,) * (ndim - len(spec))


# fwd custom partition
_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10))
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
def _infer_fwd_output_sharding(mesh, arg_shapes):
  # only sharding on batch and num_head dim is allowed
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  value_spec = _get_padded_spec(arg_shapes[2])
  if not query_spec == key_spec == value_spec:
    raise ValueError("Query, key and value should have same sharding.")
  seq_spec = query_spec[-3]
  head_spec = query_spec[-1]
  if seq_spec != None:
    raise ValueError("Sharding on sequence dim is not allowed.")
  if head_spec != None:
    raise ValueError("Sharding on head dim is not allowed.")
  # keep out sharding same as query sharding since they have same shape
  out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  # activation sharding
  if query_spec[-3] == key_spec[-3]:
    # self attention
    activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
  else:
    # cross attention
    activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3]))
  activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
  return (out_sharding, activation_sharding)

_dot_product_attention_fwd_lower = custom_partitioning(_dot_product_attention_fwd_impl, static_argnums=(5,6,7,8,9,10))
def _dot_product_attention_fwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
  return _infer_fwd_output_sharding(mesh, arg_shapes)

def _dot_product_attention_fwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep out sharding same as query sharding since they have same shape
  out_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  # activation sharding
  if query_spec[-3] == key_spec[-3]:
    # self attention
    activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], None))
  else:
    # cross attention
    activation_sharding = NamedSharding(mesh, PartitionSpec(*query_spec[:-3], query_spec[-2], query_spec[-3], key_spec[-3]))
  # args sharding
  arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
  out_shardings = (out_sharding, activation_sharding)
  out_shardings = _infer_fwd_output_sharding(mesh, arg_shapes)
  impl = partial(_dot_product_attention_fwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
                 variadic_args=variadic_args, is_flash_attention=is_flash_attention, is_causal_mask=is_causal_mask)
  return mesh, impl, out_shardings, arg_shardings

# bwd custom partition
_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13))
def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
def _infer_bwd_output_sharding(mesh, arg_shapes):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
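Reviewer note: to make the new constraint concrete, with Q/K/V laid out as (batch, seq, num_heads, head_dim), only the batch and num_heads axes may carry a mesh axis. A hedged sketch on a hypothetical 2x2 mesh with axes ('dp', 'tp'); it assumes at least four devices and the specs are purely illustrative:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:4]).reshape(2, 2)
with Mesh(devices, ('dp', 'tp')) as mesh:
  # Allowed: batch sharded over 'dp', num_heads sharded over 'tp'.
  qkv_sharding = NamedSharding(mesh, PartitionSpec('dp', None, 'tp', None))
  # Rejected by _infer_fwd_output_sharding: the sequence axis carries a mesh axis.
  seq_sharded = NamedSharding(mesh, PartitionSpec('dp', 'tp', None, None))
  # Also rejected: the head_dim axis carries a mesh axis.
  head_sharded = NamedSharding(mesh, PartitionSpec('dp', None, None, 'tp'))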
@@ -581,16 +579,12 @@ def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout
  out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
  return out_shardings

_dot_product_attention_bwd_lower = custom_partitioning(_dot_product_attention_bwd_impl, static_argnums=(8,9,10,11,12,13))
def _dot_product_attention_bwd_infer_sharding_from_operands(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
  return _infer_bwd_output_sharding(mesh, arg_shapes)

def _dot_product_attention_bwd_partition(scale, seed, dropout_rate, variadic_args, is_flash_attention, is_causal_mask, mesh, arg_shapes, result_shape):
  # (*batch, q_seq, num_head, head)
  query_spec = _get_padded_spec(arg_shapes[0])
  # (*batch, kv_seq, num_head, head)
  key_spec = _get_padded_spec(arg_shapes[1])
  # keep grad query sharding same as query sharding
  grad_query_sharding = NamedSharding(mesh, PartitionSpec(*query_spec))
  grad_key_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  grad_value_sharding = NamedSharding(mesh, PartitionSpec(*key_spec))
  out_shardings = (grad_query_sharding, grad_key_sharding, grad_value_sharding)
  out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes)
  # args sharding
  arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes])
  impl = partial(_dot_product_attention_bwd_impl, scale=scale, seed=seed, dropout_rate=dropout_rate,
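Reviewer note: _dot_product_attention_fwd_lower and _dot_product_attention_bwd_lower are jax.experimental.custom_partitioning objects, so the infer/partition functions above get attached to them through def_partition. The sketch below shows the assumed wiring; the actual registration lives elsewhere in fused_attention_stablehlo.py and may differ in detail:

# Assumed registration pattern, not copied from the diff.
_dot_product_attention_fwd_lower.def_partition(
    infer_sharding_from_operands=_dot_product_attention_fwd_infer_sharding_from_operands,
    partition=_dot_product_attention_fwd_partition)

_dot_product_attention_bwd_lower.def_partition(
    infer_sharding_from_operands=_dot_product_attention_bwd_infer_sharding_from_operands,
    partition=_dot_product_attention_bwd_partition)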
@@ -712,9 +706,9 @@ def dot_product_attention(query: Array,
  # check if query, key and value layout meets cuDNN layout requirement
  check_qkv_layout(query, key, value)
  # check if flash attention is supported for this attention pattern
  is_flash_attention = check_is_flash_attention(query, key)
  is_flash_attention, is_cross_attention = check_is_flash_attention(query, key)
  # check if cuDNN is installed and if cuDNN version is sufficient
  check_cudnn_version(is_flash_attention)
  check_cudnn_version(is_flash_attention, is_cross_attention)

  variadic_args = (bias is not None, mask is not None)
  if bias is None:
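Reviewer note: a concrete reading of the new two-value contract, with hypothetical shapes (BSHD layout: batch, seq, num_heads, head_dim) and assuming check_is_flash_attention is imported from the module edited above:

import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import check_is_flash_attention

# Hypothetical inputs: both sequence lengths exceed 512 and head_dim is 64.
query = jnp.zeros((4, 1024, 8, 64), dtype=jnp.bfloat16)
key = jnp.zeros((4, 2048, 8, 64), dtype=jnp.bfloat16)

# The flash path is selected, and the differing sequence lengths mark this as
# cross attention, which check_cudnn_version then gates on cuDNN >= 8.9.4.
is_flash_attention, is_cross_attention = check_is_flash_attention(query, key)
assert is_flash_attention and is_cross_attention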
@@ -23,7 +23,6 @@ import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax.experimental.pjit import pjit
from jax._src import config
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention
@@ -168,24 +167,27 @@ class DotProductAttentionTest(jtu.JaxTestCase):
    devices = devices.reshape((2, 2))
    with Mesh(devices, ('dp', 'tp')) as mesh:
      qkv_spec = PartitionSpec('dp', None, 'tp', None)
      qkv_sharding = NamedSharding(mesh, qkv_spec)
      if bias is not None:
        bias_spec = PartitionSpec('dp', 'tp', None, None)
      else:
        bias_spec = None
      query = jax.device_put(query, NamedSharding(mesh, qkv_spec))
      key = jax.device_put(key, NamedSharding(mesh, qkv_spec))
      value = jax.device_put(value, NamedSharding(mesh, qkv_spec))
        bias_spec = PartitionSpec()
      bias_sharding = NamedSharding(mesh, bias_spec)
      replicated = NamedSharding(mesh, PartitionSpec())
      query = jax.device_put(query, qkv_sharding)
      key = jax.device_put(key, qkv_sharding)
      value = jax.device_put(value, qkv_sharding)
      if bias is not None:
        bias = jax.device_put(bias, NamedSharding(mesh, bias_spec))
      grad = jax.device_put(grad, NamedSharding(mesh, qkv_spec))
      in_shardings = (qkv_spec, qkv_spec, qkv_spec, qkv_spec, bias_spec, None)
      out_shardings = (None, (qkv_spec, qkv_spec, qkv_spec))
      pjitted_f_train = pjit(jitted_f_train,
        bias = jax.device_put(bias, bias_sharding)
      grad = jax.device_put(grad, qkv_sharding)
      in_shardings = (qkv_sharding, qkv_sharding, qkv_sharding, qkv_sharding, bias_sharding, replicated)
      out_shardings = (replicated, (qkv_sharding, qkv_sharding, qkv_sharding))
      pjitted_f_train = jax.jit(jitted_f_train,
          in_shardings=in_shardings,
          out_shardings=out_shardings
      )

      pjitted_g_train = pjit(jitted_g_train,
      pjitted_g_train = jax.jit(jitted_g_train,
          in_shardings=in_shardings,
          out_shardings=out_shardings
      )
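Reviewer note: this test change is the commit title in miniature: jax.jit accepts in_shardings/out_shardings directly, so pjit is no longer needed, and shardings are passed as NamedSharding objects instead of bare PartitionSpecs. A self-contained sketch of the same pattern on a hypothetical 2x2 mesh (needs at least four devices; f is a stand-in for the jitted train functions in the test):

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices()[:4]).reshape(2, 2)
with Mesh(devices, ('dp', 'tp')) as mesh:
  # Shard batch over 'dp' and num_heads over 'tp'; seq and head_dim stay unsharded.
  qkv_sharding = NamedSharding(mesh, PartitionSpec('dp', None, 'tp', None))

  def f(q, k, v):
    # Stand-in for the attention train step; shape-preserving.
    return q + k + v

  jitted_f = jax.jit(f,
                     in_shardings=(qkv_sharding, qkv_sharding, qkv_sharding),
                     out_shardings=qkv_sharding)

  q = jax.device_put(jnp.ones((4, 1024, 8, 64)), qkv_sharding)
  k = jax.device_put(jnp.ones((4, 1024, 8, 64)), qkv_sharding)
  v = jax.device_put(jnp.ones((4, 1024, 8, 64)), qkv_sharding)
  out = jitted_f(q, k, v)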