Support FP8 for dot_product_attention

wenscarl 2024-12-13 17:52:31 +00:00 committed by shuw
parent 87b66f3c35
commit c67b651314
2 changed files with 928 additions and 116 deletions

File diff suppressed because it is too large

@@ -38,6 +38,37 @@ from jax._src.cudnn.fused_attention_stablehlo import (
config.parse_flags_with_absl()
Array = jnp.ndarray
fp8_meta_names = [
    'amax_dQ', 'amax_dK', 'amax_dV', 'amax_dP',
    'descale_q', 'descale_k', 'descale_v', 'descale_s',
    'scale_s', 'scale_o', 'descale_o', 'descale_dO',
    'descale_dP', 'scale_dQ', 'scale_dK', 'scale_dV', 'scale_dP',
]

fp8_metas = {
    name: jnp.ones((1, 1, 1, 1), dtype=jnp.float32) for name in fp8_meta_names
}
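The metadata tensors above are all-ones placeholders, which is enough for these tests. For orientation only (not part of this diff): in a delayed-scaling recipe each scale would typically be derived from a running amax so that the observed value range maps onto the FP8 dtype's representable range. A minimal sketch, where scale_from_amax is a hypothetical helper and not anything this commit adds:

def scale_from_amax(amax, q_dtype=jnp.float8_e4m3fn):
  # Hypothetical delayed-scaling rule: choose the scale so that dividing by
  # it maps the observed max magnitude onto the fp8 dtype's max
  # (448 for float8_e4m3fn).
  dtype_max = jnp.finfo(q_dtype).max.astype(jnp.float32)
  return (amax / dtype_max).astype(jnp.float32)
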
def quantize_to_fp8(x, q_dtype, compute_dtype, scale):
  # Explicitly cast the max values to the compute dtype to avoid unnecessary
  # casting to FP32 during the subsequent math operations.
  assert q_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2,
                     jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz)
  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
  scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape)
  clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
  return clipped_x.astype(q_dtype)
def quantize_dequantize_fp8(x, q_dtype, scale, compute_dtype):
  qx = quantize_to_fp8(x, q_dtype, compute_dtype, scale)
  out = qx.astype(x.dtype) * jnp.broadcast_to(scale.astype(x.dtype), qx.shape)
  return out

cast_to_representable = partial(
    quantize_dequantize_fp8, scale=jnp.ones((1,)), compute_dtype=jnp.bfloat16
)
quantize = partial(quantize_to_fp8, scale=jnp.ones((1,)))
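Not part of the diff, but to make the helpers concrete: with the unit scales bound above, values inside the e4m3 range survive a quantize/dequantize round trip unchanged, while out-of-range values clip to the dtype max (448 for float8_e4m3fn).

x = jnp.array([0.5, 1.5, 448.0, 1000.0], dtype=jnp.bfloat16)
x_q = quantize(x, jnp.float8_e4m3fn, jnp.float32)   # float8_e4m3fn values
x_dq = cast_to_representable(x, jnp.float8_e4m3fn)  # [0.5, 1.5, 448., 448.] in bfloat16
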
def sdpa_train(query: Array,
               key: Array,
               value: Array,
@@ -171,6 +202,25 @@ def sdpa_train_ref(query: Array,
    return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref)
  return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref)
def sdpa_train_fp8(query: Array,
                   key: Array,
                   value: Array,
                   grad: Array,
                   fp8_metas: dict[str, Array],
                   scale: float = 0.5,
                   mask_type: MaskType = MaskType.NO_MASK):
  def dot_product_attention_fp8(query, key, value, fp8_metas):
    f_p = partial(
        dot_product_attention, scale=scale, mask_type=mask_type, use_fp8=True)
    return f_p(query, key, value, None, None, None, None, fp8_metas)

  out, sdpa_vjp = jax.vjp(
      dot_product_attention_fp8, query, key, value, fp8_metas)
  # The FP8 forward returns (out, amax_s, amax_o), so the pullback takes a
  # matching three-element cotangent; the amax cotangents are dummies here,
  # and the trailing fp8-meta gradients are discarded via *_.
  grad_amax_s = jnp.ones((1, 1, 1, 1), dtype=jnp.float32)
  grad_amax_o = jnp.ones((1, 1, 1, 1), dtype=jnp.float32)
  query_grad, key_grad, value_grad, *_ = sdpa_vjp((grad, grad_amax_s, grad_amax_o))
  return out[0], (query_grad, key_grad, value_grad)
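For orientation (again, not part of the diff), the helper above might be driven as sketched below. The shapes are hypothetical, and a Hopper-class GPU with a recent cuDNN is required for the FP8 kernel to be selected.

B, T, N, H = 2, 128, 4, 128  # hypothetical BTNH sizes
keys = jax.random.split(jax.random.key(0), 4)
q, k, v, g = (jax.random.normal(kk, (B, T, N, H), jnp.bfloat16)
              .astype(jnp.float8_e4m3fn) for kk in keys)
out, (dq, dk, dv) = sdpa_train_fp8(q, k, v, g, fp8_metas)
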
class DotProductAttentionTest(jtu.JaxTestCase):
  def setUp(self):
    super().setUp()
@@ -202,7 +252,7 @@ class DotProductAttentionTest(jtu.JaxTestCase):
  def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int,
                head_dim: int, use_mask: bool, use_bias: bool, mask_type: MaskType,
                dropout_rate: float, scale: float, dtype: jnp.dtype):
    if len(jax.local_devices()) <= 4:
    if len(jax.local_devices()) < 4:
      self.skipTest("Require at least 4 devices to run sharding tests.")
    if use_mask and mask_type != MaskType.NO_MASK:
      self.skipTest("Either pass in mask or generate mask directly in cuDNN.")
@@ -543,5 +593,116 @@ class DotProductAttentionTest(jtu.JaxTestCase):
query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
is_training)
@jtu.with_config(jax_numpy_dtype_promotion='standard')
class DotProductAttentionF8Test(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    try:
      cudnn_version = check_cudnn_version()
    except RuntimeError as e:
      self.skipTest(str(e))
      return
    if not jtu.is_cuda_compute_capability_at_least("9.0"):
      self.skipTest("Requires at least Hopper arch")
  @jtu.sample_product(
      batch_size=[2, 4],
      seq_len=[128, 256],
      num_heads=[4, 8],
      head_dim=[128],
      mask_type=[MaskType.NO_MASK],
      scale=[1.0, 0.75],
      dtype=[jnp.bfloat16, jnp.float16]
  )
  @jtu.run_on_devices("cuda")
  def test_sdpa_fp8(self, batch_size: int, seq_len: int, num_heads: int,
                    head_dim: int, mask_type: MaskType,
                    scale: float, dtype: jnp.dtype):
    k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
    input_shape = (batch_size, seq_len, num_heads, head_dim)  # only test the default BTNH
    query_h = jax.random.normal(k1, input_shape, dtype=dtype)
    key_h = jax.random.normal(k2, input_shape, dtype=dtype)
    value_h = jax.random.normal(k3, input_shape, dtype=dtype)
    grad_h = jax.random.normal(k4, input_shape, dtype=dtype)
    query = cast_to_representable(query_h, jnp.float8_e4m3fn)
    key = cast_to_representable(key_h, jnp.float8_e4m3fn)
    value = cast_to_representable(value_h, jnp.float8_e4m3fn)
    grad = cast_to_representable(grad_h, jnp.float8_e4m3fn)
    query_quantized = quantize(query, jnp.float8_e4m3fn, jnp.float32)
    key_quantized = quantize(key, jnp.float8_e4m3fn, jnp.float32)
    value_quantized = quantize(value, jnp.float8_e4m3fn, jnp.float32)
    grad_quantized = quantize(grad, jnp.float8_e4m3fn, jnp.float32)
    sdpa_train_fp8_p = partial(sdpa_train_fp8, scale=scale, mask_type=mask_type)
    jitted_sdpa_train_fp8 = jax.jit(sdpa_train_fp8_p)
    jitted_sdpa_train_ref = jax.jit(
        partial(
            sdpa_train_ref, scale=scale, mask_type=mask_type, dropout_rate=0.0),
    )
    out, (query_grad, key_grad, value_grad) = jitted_sdpa_train_fp8(
        query_quantized, key_quantized, value_quantized, grad_quantized,
        fp8_metas)
    out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \
        jitted_sdpa_train_ref(query, key, value, grad, None, None)
    self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=5e-1, atol=5e-1)
    self.assertArraysAllClose(
        query_grad_ref, query_grad.astype(dtype), rtol=5e-1, atol=3e0)
    self.assertArraysAllClose(
        key_grad_ref, key_grad.astype(dtype), rtol=5e-1, atol=3e0)
    self.assertArraysAllClose(
        value_grad_ref, value_grad.astype(dtype), rtol=5e-1, atol=5e-1)
  @jtu.sample_product(
      batch_size=[4, 2],
      seq_len=[4, 16],
      num_heads=[4, 16],
      head_dim=[16, 32],
      mask_type=[MaskType.NO_MASK],
      qkv_layout=["BNTH", "BTNH"],
      scale=[1.0, 0.75],
      dtype=[jnp.bfloat16, jnp.float16]
  )
  @jtu.run_on_devices("cuda")
  def test_sdpa_fp8_inference(self, batch_size: int, seq_len: int,
                              num_heads: int, head_dim: int,
                              mask_type: MaskType, qkv_layout: str,
                              scale: float, dtype: jnp.dtype):
    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
    if qkv_layout == "BNTH":
      input_shape = (batch_size, num_heads, seq_len, head_dim)
    else:
      input_shape = (batch_size, seq_len, num_heads, head_dim)
    query_h = jax.random.normal(k1, input_shape, dtype=dtype)
    key_h = jax.random.normal(k2, input_shape, dtype=dtype)
    value_h = jax.random.normal(k3, input_shape, dtype=dtype)
    query = cast_to_representable(query_h, jnp.float8_e4m3fn)
    key = cast_to_representable(key_h, jnp.float8_e4m3fn)
    value = cast_to_representable(value_h, jnp.float8_e4m3fn)
    query_quantized = quantize(query, jnp.float8_e4m3fn, jnp.float32)
    key_quantized = quantize(key, jnp.float8_e4m3fn, jnp.float32)
    value_quantized = quantize(value, jnp.float8_e4m3fn, jnp.float32)

    def dot_product_attention_fp8(query, key, value, fp8_metas):
      f_p = partial(
          dot_product_attention, scale=scale, mask_type=mask_type,
          qkv_layout=qkv_layout, use_fp8=True)
      return f_p(query, key, value, None, None, None, None, fp8_metas)

    jitted_sdpa_inference = jax.jit(dot_product_attention_fp8)
    jitted_sdpa_inference_ref = jax.jit(
        partial(
            dot_product_attention, scale=scale, mask_type=mask_type,
            qkv_layout=qkv_layout),
    )
    # The FP8 path also returns the amax_s and amax_o statistics; only the
    # attention output is compared against the reference.
    out, _, _ = jitted_sdpa_inference(
        query_quantized, key_quantized, value_quantized, fp8_metas)
    out_ref = jitted_sdpa_inference_ref(query, key, value)
    self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=5e-2, atol=5e-2)
if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())