Support FP8 for dot_product_attention
parent 87b66f3c35
commit c67b651314
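The new tests below exercise an FP8 path in the internal cuDNN attention wrapper (jax._src.cudnn.fused_attention_stablehlo.dot_product_attention). As a minimal sketch of the calling convention they rely on, distilled from the added test code rather than from any documented API (the four positional None slots, the fp8_metas dict, and the (output, amax_s, amax_o) return triple simply mirror how the tests invoke it; running it needs cuDNN and an FP8-capable GPU, which the tests gate at compute capability 9.0):

# Sketch only, based on the tests added in this commit; not an authoritative API reference.
from functools import partial
import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention, MaskType

# Per-tensor scale/amax metadata, all initialized to ones as in the tests below.
fp8_meta_names = [
    'amax_dQ', 'amax_dK', 'amax_dV', 'amax_dP',
    'descale_q', 'descale_k', 'descale_v', 'descale_s',
    'scale_s', 'scale_o', 'descale_o', 'descale_dO',
    'descale_dP', 'scale_dQ', 'scale_dK', 'scale_dV', 'scale_dP',
]
fp8_metas = {n: jnp.ones((1, 1, 1, 1), dtype=jnp.float32) for n in fp8_meta_names}

# BTNH inputs cast to FP8 (shapes chosen arbitrarily for this sketch).
k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
shape = (2, 128, 4, 128)  # batch, seq_len, num_heads, head_dim
q_fp8 = jax.random.normal(k1, shape, dtype=jnp.bfloat16).astype(jnp.float8_e4m3fn)
k_fp8 = jax.random.normal(k2, shape, dtype=jnp.bfloat16).astype(jnp.float8_e4m3fn)
v_fp8 = jax.random.normal(k3, shape, dtype=jnp.bfloat16).astype(jnp.float8_e4m3fn)

# use_fp8=True selects the FP8 kernel; the four None slots are passed exactly as the tests do.
attn = partial(dot_product_attention, scale=1.0, mask_type=MaskType.NO_MASK, use_fp8=True)
out, amax_s, amax_o = attn(q_fp8, k_fp8, v_fp8, None, None, None, None, fp8_metas)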
@@ -38,6 +38,37 @@ from jax._src.cudnn.fused_attention_stablehlo import (

config.parse_flags_with_absl()
Array = jnp.ndarray

fp8_meta_names = [
    'amax_dQ', 'amax_dK', 'amax_dV', 'amax_dP',
    'descale_q', 'descale_k', 'descale_v', 'descale_s',
    'scale_s', 'scale_o', 'descale_o', 'descale_dO',
    'descale_dP', 'scale_dQ', 'scale_dK', 'scale_dV', 'scale_dP',
]

fp8_metas = {name: jnp.ones((1, 1, 1, 1), dtype=jnp.float32) for name in fp8_meta_names}


def quantize_to_fp8(x, q_dtype, compute_dtype, scale):
  # Explicitly cast the max values to the compute dtype to avoid unnecessary
  # casting to FP32 during the subsequent math operations.
  assert q_dtype in (jnp.float8_e4m3fn, jnp.float8_e5m2,
                     jnp.float8_e4m3fnuz, jnp.float8_e5m2fnuz)
  dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
  scaled_x = x / jnp.broadcast_to(scale.astype(compute_dtype), x.shape)
  clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
  return clipped_x.astype(q_dtype)


def quantize_dequantize_fp8(x, q_dtype, scale, compute_dtype):
  qx = quantize_to_fp8(x, q_dtype, compute_dtype, scale)
  out = qx.astype(x.dtype) * jnp.broadcast_to(scale.astype(x.dtype), qx.shape)
  return out


cast_to_representable = partial(
    quantize_dequantize_fp8, scale=jnp.ones((1,)), compute_dtype=jnp.bfloat16
)

quantize = partial(quantize_to_fp8, scale=jnp.ones((1,)))
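# Illustration (not part of this diff): with scale = 1, quantize_to_fp8 saturates values at the
# target dtype's finite max; for float8_e4m3fn that is +/-448, so out-of-range inputs clip there.
_demo_x = jnp.array([500.0, -1.0, 0.5], dtype=jnp.bfloat16)
_demo_qx = quantize_to_fp8(_demo_x, jnp.float8_e4m3fn, jnp.bfloat16, jnp.ones((1,)))
# _demo_qx ~= [448, -1, 0.5], stored as float8_e4m3fn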

def sdpa_train(query: Array,
               key: Array,
               value: Array,
@@ -171,6 +202,25 @@ def sdpa_train_ref(query: Array,
    return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref)
  return out_ref, (query_grad_ref, key_grad_ref, value_grad_ref)

def sdpa_train_fp8(query: Array,
                   key: Array,
                   value: Array,
                   grad: Array,
                   fp8_metas: dict[str, Array],
                   scale: float = 0.5,
                   mask_type: MaskType = MaskType.NO_MASK):
  def dot_product_attention_fp8(query, key, value, fp8_metas):
    f_p = partial(
        dot_product_attention, scale=scale, mask_type=mask_type, use_fp8=True)
    return f_p(query, key, value, None, None, None, None, fp8_metas)
  out, sdpa_vjp = jax.vjp(
      dot_product_attention_fp8, query, key, value, fp8_metas)

  grad_amax_s = jnp.ones((1, 1, 1, 1), dtype=jnp.float32)
  grad_amax_o = jnp.ones((1, 1, 1, 1), dtype=jnp.float32)
  query_grad, key_grad, value_grad, *_ = sdpa_vjp((grad, grad_amax_s, grad_amax_o))
  return out[0], (query_grad, key_grad, value_grad)
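# Annotation (not part of this diff): with use_fp8=True the forward call above appears to return a
# (output, amax_s, amax_o) triple, so the VJP is fed a matching three-element cotangent
# (grad, grad_amax_s, grad_amax_o) and only out[0], the attention output itself, is returned.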

class DotProductAttentionTest(jtu.JaxTestCase):
  def setUp(self):
    super().setUp()
@@ -202,7 +252,7 @@ class DotProductAttentionTest(jtu.JaxTestCase):
  def test_sdpa(self, batch_size: int, seq_len: int, num_heads: int,
                head_dim: int, use_mask: bool, use_bias: bool, mask_type: MaskType,
                dropout_rate: float, scale: float, dtype: jnp.dtype):
-    if len(jax.local_devices()) <= 4:
+    if len(jax.local_devices()) < 4:
      self.skipTest("Require at least 4 devices to run sharding tests.")
    if use_mask and mask_type != MaskType.NO_MASK:
      self.skipTest("Either pass in mask or generate mask directly in cuDNN.")
@@ -543,5 +593,116 @@ class DotProductAttentionTest(jtu.JaxTestCase):
        query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias,
        is_training)


@jtu.with_config(jax_numpy_dtype_promotion='standard')
class DotProductAttentionF8Test(jtu.JaxTestCase):
  def setUp(self):
    super().setUp()
    try:
      cudnn_version = check_cudnn_version()
    except RuntimeError as e:
      self.skipTest(str(e))
      return
    if not jtu.is_cuda_compute_capability_at_least("9.0"):
      self.skipTest("Requires at least Hopper arch")

  @jtu.sample_product(
      batch_size=[2, 4],
      seq_len=[128, 256],
      num_heads=[4, 8],
      head_dim=[128],
      mask_type=[MaskType.NO_MASK],
      scale=[1.0, 0.75],
      dtype=[jnp.bfloat16, jnp.float16]
  )
  @jtu.run_on_devices("cuda")
  def test_sdpa_fp8(self, batch_size: int, seq_len: int, num_heads: int,
                    head_dim: int, mask_type: MaskType,
                    scale: float, dtype: jnp.dtype):
    k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)
    input_shape = (batch_size, seq_len, num_heads, head_dim)  # only test the default BTNH
    query_h = jax.random.normal(k1, input_shape, dtype=dtype)
    key_h = jax.random.normal(k2, input_shape, dtype=dtype)
    value_h = jax.random.normal(k3, input_shape, dtype=dtype)
    grad_h = jax.random.normal(k4, input_shape, dtype=dtype)
    query = cast_to_representable(query_h, jnp.float8_e4m3fn)
    key = cast_to_representable(key_h, jnp.float8_e4m3fn)
    value = cast_to_representable(value_h, jnp.float8_e4m3fn)
    grad = cast_to_representable(grad_h, jnp.float8_e4m3fn)

    query_quantized = quantize(query, jnp.float8_e4m3fn, jnp.float32)
    key_quantized = quantize(key, jnp.float8_e4m3fn, jnp.float32)
    value_quantized = quantize(value, jnp.float8_e4m3fn, jnp.float32)
    grad_quantized = quantize(grad, jnp.float8_e4m3fn, jnp.float32)

    sdpa_train_fp8_p = partial(sdpa_train_fp8, scale=scale, mask_type=mask_type)
    jitted_sdpa_train_fp8 = jax.jit(sdpa_train_fp8_p)
    jitted_sdpa_train_ref = jax.jit(
        partial(
            sdpa_train_ref, scale=scale, mask_type=mask_type, dropout_rate=0.0),
    )

    out, (query_grad, key_grad, value_grad) = \
        jitted_sdpa_train_fp8(query_quantized, key_quantized, value_quantized, grad_quantized, fp8_metas)
    out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \
        jitted_sdpa_train_ref(query, key, value, grad, None, None)

    self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=5e-1, atol=5e-1)
    self.assertArraysAllClose(query_grad_ref, query_grad.astype(dtype), rtol=5e-1, atol=3e0)
    self.assertArraysAllClose(key_grad_ref, key_grad.astype(dtype), rtol=5e-1, atol=3e0)
    self.assertArraysAllClose(value_grad_ref, value_grad.astype(dtype), rtol=5e-1, atol=5e-1)

  @jtu.sample_product(
      batch_size=[4, 2],
      seq_len=[4, 16],
      num_heads=[4, 16],
      head_dim=[16, 32],
      mask_type=[MaskType.NO_MASK],
      qkv_layout=["BNTH", "BTNH"],
      scale=[1.0, 0.75],
      dtype=[jnp.bfloat16, jnp.float16]
  )
  @jtu.run_on_devices("cuda")
  def test_sdpa_fp8_inference(self, batch_size: int, seq_len: int, num_heads: int,
                              head_dim: int, mask_type: MaskType, qkv_layout: str,
                              scale: float, dtype: jnp.dtype):
    k1, k2, k3 = jax.random.split(jax.random.key(0), 3)
    if qkv_layout == "BNTH":
      input_shape = (batch_size, num_heads, seq_len, head_dim)
    else:
      input_shape = (batch_size, seq_len, num_heads, head_dim)
    query_h = jax.random.normal(k1, input_shape, dtype=dtype)
    key_h = jax.random.normal(k2, input_shape, dtype=dtype)
    value_h = jax.random.normal(k3, input_shape, dtype=dtype)

    query = cast_to_representable(query_h, jnp.float8_e4m3fn)
    key = cast_to_representable(key_h, jnp.float8_e4m3fn)
    value = cast_to_representable(value_h, jnp.float8_e4m3fn)

    query_quantized = quantize(query, jnp.float8_e4m3fn, jnp.float32)
    key_quantized = quantize(key, jnp.float8_e4m3fn, jnp.float32)
    value_quantized = quantize(value, jnp.float8_e4m3fn, jnp.float32)

    def dot_product_attention_fp8(query, key, value, fp8_metas):
      f_p = partial(
          dot_product_attention, scale=scale, mask_type=mask_type, qkv_layout=qkv_layout, use_fp8=True)
      return f_p(query, key, value, None, None, None, None, fp8_metas)

    jitted_sdpa_inference = jax.jit(
        dot_product_attention_fp8,
    )

    jitted_sdpa_inference_ref = jax.jit(
        partial(
            dot_product_attention, scale=scale, mask_type=mask_type, qkv_layout=qkv_layout),
    )
    out, _, _ = jitted_sdpa_inference(query_quantized, key_quantized, value_quantized, fp8_metas)
    out_ref = jitted_sdpa_inference_ref(query, key, value)
    self.assertArraysAllClose(out_ref, out.astype(dtype), rtol=5e-2, atol=5e-2)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())