Merge pull request #26345 from wenscarl:scaled_matmul

PiperOrigin-RevId: 731865430
jax authors 2025-02-27 14:24:48 -08:00
commit c7ca35fe32
7 changed files with 1517 additions and 7 deletions


@ -0,0 +1,696 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import json
import operator
from functools import partial, reduce
from typing import List
# Third-party imports
import jax
import jax.numpy as jnp
import numpy as np
from jax import custom_vjp, lax
from jax._src import core, dispatch, dtypes
from jax._src.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src.lax.lax import ranges_like, remaining
from jax._src.typing import DTypeLike
from jax.interpreters import mlir, xla
from jax.interpreters.mlir import ir
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
Array = jnp.ndarray
block_scaled_dot_name = "__op$block_scaled_dot"
@dataclass
class BlockScaleConfig:
mode: str
block_size: int
data_type: DTypeLike
scale_type: DTypeLike
global_scale: Array | None
infer_only: bool
def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]
def element_type_to_backend_config_type(dtype):
_element_type_to_backend_config_type_mapping = {
ir.BF16Type.get(): "BF16",
ir.F16Type.get(): "F16",
ir.F32Type.get(): "F32",
}
return _element_type_to_backend_config_type_mapping[dtype]
def _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type):
return _scaled_matmul_p.bind(
a, b, a_scale, b_scale, preferred_element_type=preferred_element_type
)
def _scaled_matmul_cuda_lowering(
ctx, a, b, a_scales, b_scales, preferred_element_type
):
lhs_type = ir.RankedTensorType(a.type)
lhs_shape = lhs_type.shape
rhs_type = ir.RankedTensorType(b.type)
rhs_shape = rhs_type.shape
batch, non_contracting_lhs, contracting = lhs_shape
_, non_contracting_rhs, _ = rhs_shape
result_shape = (batch, non_contracting_lhs, non_contracting_rhs)
out_type = mlir.dtype_to_ir_type(preferred_element_type)
result_types = [ir.RankedTensorType.get(result_shape, out_type)]
operands = [a, b, a_scales, b_scales]
backend_config = {
"scaled_dot_backend_config": {
"lhs_batch_dimensions": [0],
"rhs_batch_dimensions": [0],
"dequantize_type": element_type_to_backend_config_type(out_type),
}
}
backend_config = json.dumps(backend_config)
out = mlir.custom_call(
block_scaled_dot_name,
result_types=result_types,
operands=operands,
backend_config=backend_config,
operand_layouts=default_layouts(
*[ir.RankedTensorType(operand.type).shape for operand in operands]
),
result_layouts=default_layouts(result_shape),
)
return [out.result]
def _scaled_matmul_abstract(a, b, a_scale, b_scale, *, preferred_element_type):
a_dtype = dtypes.canonicalize_dtype(a.dtype)
batch, non_contracting_lhs, contracting_lhs = a.shape
_, non_contracting_rhs, _ = b.shape
output_shape = (batch, non_contracting_lhs, non_contracting_rhs)
return (core.ShapedArray(output_shape, preferred_element_type),)
_scaled_matmul_p = core.Primitive("scaled_matmul")
_scaled_matmul_p.multiple_results = True
_scaled_matmul_p.def_impl(partial(xla.apply_primitive, _scaled_matmul_p))
_scaled_matmul_p.def_abstract_eval(_scaled_matmul_abstract)
mlir.register_lowering(
_scaled_matmul_p,
_scaled_matmul_cuda_lowering,
platform="cuda",
)
_scaled_matmul_p_wrapper = core.Primitive("scaled_matmul_wrapper")
_scaled_matmul_p_wrapper.multiple_results = True
_scaled_matmul_p_wrapper.def_impl(_scaled_matmul_impl)
_scaled_matmul_p_wrapper.def_abstract_eval(_scaled_matmul_abstract)
# Given inputs already sharded as
#   ([B], M, K1), ([B], N, K2)
# we apply the following rules to decide which AllGathers are needed (the
# "Input specs") and what the resulting "Output spec" is.
# 1. If K1 == K2 != None and N == None:
# - Input spec : ([B], M, K1), ([B], None, K2)
# - Output spec: ([B], M, None) -> AllReduce -> ([B], M, None)
# 2. If K1 == K2 != None and M == N != None:
# - Input spec : ([B], M, K1), ([B], None, K2)
# - Output spec: ([B], M, None) -> ReduceScatter -> ([B], M, N)
# 3. If N == M:
# - Input specs : ([B], M, None), ([B], None, None)
# - Output specs: ([B], M, None)
# 4. If N != M:
# - Input spec : ([B], M, None), ([B], N, None)
# - Output spec: ([B], M, N)
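#
# For example, with mesh axes ("dp", "tp") and both operands sharded as
# P("dp", None, "tp"), rule 1 applies (K1 == K2 == "tp", N is None): each
# device computes a partial product over its "tp" slice of K, and an
# all-reduce over "tp" yields the output sharded as P("dp", None, None).
# The collectives themselves are emitted in _scaled_matmul_partition below.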
def _check_shardings(shardings):
if len(shardings) != 4:
msg = f"shardings should container 4 inputs, but got {len(shardings)}"
raise TypeError(msg)
lhs, rhs, _, _ = shardings
if len(lhs.spec) != 3 or len(rhs.spec) != 3:
msg = (f'sharding spec ranks should be 3, but got lhs: {len(lhs.spec)} '
f'and rhs: {len(rhs.spec)}')
raise TypeError(msg)
if lhs.spec[0] != rhs.spec[0]:
msg = ('sharding specs for the batch dim should match, but got lhs: '
f'{lhs.spec[0]} and rhs: {rhs.spec[0]}')
raise TypeError(msg)
def _enable_reduce_scatter(lhs, rhs):
batch_spec, m_spec, lhs_k_spec = lhs.spec
_, n_spec, rhs_k_spec = rhs.spec
return (
lhs_k_spec is not None
and lhs_k_spec == rhs_k_spec
and m_spec is not None
and m_spec == n_spec
)
def _enable_all_reduce(lhs, rhs):
batch_spec, m_spec, lhs_k_spec = lhs.spec
_, n_spec, rhs_k_spec = rhs.spec
return lhs_k_spec is not None and lhs_k_spec == rhs_k_spec and n_spec is None
def _get_output_sharding(mesh, shardings):
lhs, rhs = shardings[0], shardings[1]
batch_spec, m_spec, _ = lhs.spec
_, n_spec, _ = rhs.spec
if _enable_reduce_scatter(lhs, rhs):
return [NamedSharding(lhs.mesh, P(*lhs.spec))]
output_specs = (batch_spec, m_spec)
output_specs += (n_spec,) if m_spec != n_spec else (None,)
return [NamedSharding(lhs.mesh, P(*output_specs))]
def _scaled_matmul_infer_sharding_from_operands(
preferred_element_type, mesh, shapes, output_shape
):
shardings = jax.tree.map(lambda x: x.sharding, shapes)
_check_shardings(shardings)
return _get_output_sharding(mesh, shardings)
def supported_in_sharding(mesh, shardings):
lhs_sharding, rhs_sharding = shardings[0], shardings[1]
use_reduce_scatter = _enable_reduce_scatter(lhs_sharding, rhs_sharding)
use_all_reduce = _enable_all_reduce(lhs_sharding, rhs_sharding)
assert not (use_all_reduce and use_reduce_scatter)
lhs_specs, rhs_specs = list(lhs_sharding.spec), list(rhs_sharding.spec)
def named_sharding(lhs, rhs, lhs_specs, rhs_specs):
lhs_sharding = NamedSharding(lhs.mesh, P(*lhs_specs))
rhs_sharding = NamedSharding(rhs.mesh, P(*rhs_specs))
return (lhs_sharding, rhs_sharding, lhs_sharding, rhs_sharding)
if use_all_reduce:
return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs)
if use_reduce_scatter:
rhs_specs[1] = None
return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs)
lhs_specs[2] = None
rhs_specs[2] = None
m_spec, n_spec = lhs_specs[1], rhs_specs[1]
if m_spec == n_spec:
rhs_specs[1] = None
return named_sharding(lhs_sharding, rhs_sharding, lhs_specs, rhs_specs)
def _scaled_matmul_partition(
preferred_element_type, mesh, shapes, output_shape
):
shardings = jax.tree.map(lambda x: x.sharding, shapes)
_check_shardings(shardings)
lhs, rhs = shardings[0], shardings[1]
use_all_reduce = _enable_all_reduce(lhs, rhs)
use_reduce_scatter = _enable_reduce_scatter(lhs, rhs)
lhs_k_spec = lhs.spec[2]
def _scaled_matmul_impl_partition(a, b, a_scale, b_scale):
z = _scaled_matmul_impl(a, b, a_scale, b_scale, preferred_element_type)
if use_reduce_scatter:
z = jax.lax.psum_scatter(
z, lhs_k_spec, scatter_dimension=2, tiled=True
)
if use_all_reduce:
z = jax.lax.psum(z, lhs_k_spec)
return z
out_shardings = _get_output_sharding(mesh, shardings)
arg_shardings = supported_in_sharding(mesh, shardings)
return mesh, _scaled_matmul_impl_partition, out_shardings, arg_shardings
_scaled_matmul_lower = custom_partitioning(
_scaled_matmul_impl, static_argnums=(4,)
)
_scaled_matmul_lower.def_partition(
infer_sharding_from_operands=_scaled_matmul_infer_sharding_from_operands,
partition=_scaled_matmul_partition,
)
def _scaled_matmul_batcher(batched_args, batch_dims, *, preferred_element_type):
assert len(batch_dims) == 4
assert (
batch_dims[0] == batch_dims[1]
and batch_dims[0] == batch_dims[2]
and batch_dims[0] == batch_dims[3]
)
lhs_bdims = batch_dims[0]
out_bdims = (batch_dims[0],)
lhs, rhs, lhs_scales, rhs_scales = batched_args
*batch, lhs_non_contracting, contracting = lhs.shape
*_, _, scales_contracting = lhs_scales.shape
*_, rhs_non_contracting, _ = rhs.shape
new_batch = reduce(operator.mul, batch)
# reshape to 3D shape
lhs = jnp.reshape(lhs, (new_batch, lhs_non_contracting, contracting))
lhs_scales = jnp.reshape(
lhs_scales, (new_batch, lhs_non_contracting, scales_contracting)
)
rhs = jnp.reshape(rhs, (new_batch, rhs_non_contracting, contracting))
rhs_scales = jnp.reshape(
rhs_scales, (new_batch, rhs_non_contracting, scales_contracting)
)
output = jnp.reshape(
_scaled_matmul_p_wrapper.bind(
lhs,
rhs,
lhs_scales,
rhs_scales,
preferred_element_type=preferred_element_type,
)[0],
(*batch, lhs_non_contracting, rhs_non_contracting),
)
return (output,), out_bdims
mlir.register_lowering(
_scaled_matmul_p_wrapper,
mlir.lower_fun(_scaled_matmul_lower, multiple_results=True),
)
dispatch.prim_requires_devices_during_lowering.add(_scaled_matmul_p)
dispatch.prim_requires_devices_during_lowering.add(_scaled_matmul_p_wrapper)
batching.primitive_batchers[_scaled_matmul_p_wrapper] = _scaled_matmul_batcher
batching.primitive_batchers[_scaled_matmul_p] = _scaled_matmul_batcher
@partial(jax.jit, static_argnames=("preferred_element_type",))
def _scaled_matmul(
lhs: Array,
rhs: Array,
lhs_scales: Array,
rhs_scales: Array,
preferred_element_type: DTypeLike = jnp.float32,
) -> Array:
output = _scaled_matmul_p_wrapper.bind(
lhs, rhs, lhs_scales, rhs_scales,
preferred_element_type=preferred_element_type
)
return output[0]
def scaled_matmul_wrapper(
lhs: Array,
rhs: Array,
lhs_scales: Array,
rhs_scales: Array,
preferred_element_type: DTypeLike = jnp.float32,
) -> Array:
"""
Performs scaled matrix multiplication between two 3D arrays, with scaling
factors applied to the matrices.
Args:
lhs (Array): A 3D array of shape (B, M, K).
rhs (Array): A 3D array of shape (B, N, K).
lhs_scales (Array): A 3D array of shape (B, M, K_block).
rhs_scales (Array): A 3D array of shape (B, N, K_block).
preferred_element_type (DTypeLike, optional): The preferred data type
for the computation. Defaults to `jnp.float32`.
Returns:
Array: A 3D array of shape (B, M, N) representing the scaled matrix
multiplication result.
Raises:
AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
match the number of columns in `rhs` (`rhs_K`).
Notes:
- The function ensures that the `preferred_element_type` is
canonicalized before passing it to the underlying computation.
- Scaling is applied to the matrices based on the `lhs_scales` and
`rhs_scales` arrays, enabling efficient computations in blocks.
"""
B, M, lhs_K = lhs.shape
_, N, rhs_K = rhs.shape
assert lhs_K == rhs_K
_, _, K_block = lhs_scales.shape
preferred_element_type = dtypes.canonicalize_dtype(
np.dtype(preferred_element_type)
)
out = _scaled_matmul(
lhs,
rhs,
lhs_scales,
rhs_scales,
preferred_element_type=preferred_element_type,
)
return out
def shape_normalization(x, dimension_numbers):
"""
Normalizes the shape of the input tensor `x` to `(B, M, K)`.
This function rearranges and reshapes the input tensor `x` such that:
- `B` represents the batch dimensions.
- `M` represents the non-contracting dimensions.
- `K` represents the contracting dimensions.
The dimensions are reordered and reshaped based on the provided
`dimension_numbers`.
Parameters:
x: The input tensor to normalize.
dimension_numbers: A tuple of `(contracting_dims, batch_dims)`, where:
- `contracting_dims` (tuple): The dimensions of `x` to be treated as
contracting dimensions.
- `batch_dims` (tuple): The dimensions of `x` to be treated as batch
dimensions.
Returns:
jax.numpy.ndarray: The reshaped tensor with shape `(B, M, K)`
"""
orig_order = list(range(x.ndim))
contracting_dims, batch_dims = dimension_numbers
contracting_order = [d for d in orig_order if d in contracting_dims]
batch_order = [d for d in orig_order if d in batch_dims]
non_contracting_order = [
d
for d in orig_order
if d not in contracting_dims and d not in batch_dims
]
batch_shape = [x.shape[d] for d in batch_order]
rows_shape = [x.shape[d] for d in non_contracting_order]
cols_shape = [x.shape[d] for d in contracting_order]
new_order = batch_order + non_contracting_order + contracting_order
rows, cols, batches = (
np.prod(rows_shape),
np.prod(cols_shape),
np.prod(batch_shape, dtype=int),
)
t = jnp.transpose(x, new_order)
return jnp.reshape(t, (batches, rows, cols))
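# For example, shape_normalization(x, ((3,), (0,))) on an x of shape
# (2, 3, 4, 5) treats dim 0 as batch and dim 3 as contracting, flattens the
# remaining dims (1, 2) into M = 12, and returns an array of shape (2, 12, 5).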
def compute_dot_output_shape(
lhs_shape, rhs_shape, lhs_dimension_numbers, rhs_dimension_numbers
):
"""
Computes the output shape for a `lax.dot_general`-like operation.
"""
lhs_contract, lhs_batch = lhs_dimension_numbers[0], lhs_dimension_numbers[1]
rhs_contract, rhs_batch = rhs_dimension_numbers[0], rhs_dimension_numbers[1]
output_shape = []
# Add dimensions for batch (assuming the batch dims of LHS and RHS
# should be same)
for i, dim in enumerate(lhs_shape):
if i in lhs_batch:
output_shape.append(dim)
# Add dimensions from the LHS that are non contracting
for i, dim in enumerate(lhs_shape):
if i not in lhs_contract and i not in lhs_batch:
output_shape.append(dim)
# Add dimensions from the RHS that are non contracting
for i, dim in enumerate(rhs_shape):
if i not in rhs_contract and i not in rhs_batch:
output_shape.append(dim)
return tuple(output_shape)
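# The two helpers below handle the e8m0 scale format used by mxfp8: an e8m0
# value stores only the 8-bit biased exponent of a float32, so it can only
# represent powers of two 2**(e - 127). cast_to_e8m0_with_rounding_up keeps
# the exponent bits and rounds up whenever mantissa bits would be discarded
# (e.g. a scale of 3.0 maps to exponent 129, which decodes to 4.0), so that
# dividing by the quantized scale cannot overflow the target format.
# e8m0_to_dtype rebuilds the power-of-two scale and replaces the all-zero
# exponent (which would decode to 0.0) with a small nonzero value so the
# subsequent division is safe.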
def cast_to_e8m0_with_rounding_up(x):
temp = x.astype(jnp.float32).view(jnp.uint32)
exp = temp >> 23
mant = temp & 0x7FFFFF
is_ru = jnp.logical_and(
jnp.logical_and((mant > 0), (exp != 0xFE)),
~jnp.logical_and((exp == 0), (mant <= 0x400000))
)
exp = jnp.where(is_ru, exp + 1, exp)
new_x = exp.astype(jnp.uint8)
return new_x
def e8m0_to_dtype(x, dtype):
temp = x.astype(jnp.uint32)
exp = temp << 23
new_x = exp.view(jnp.float32)
near_zero_value = 2**-15 if dtype == jnp.float16 else 2**-127
new_x = jnp.where(
new_x == 0, jnp.array(near_zero_value, jnp.float32), new_x
)
return new_x.astype(dtype)
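# quantize() implements the block-scaling scheme described by BlockScaleConfig:
# the contracting dimension is split into blocks of `block_size`, each block is
# divided by its own amax / finfo(data_type).max scale, the scaled values are
# cast to `data_type`, and the per-block scales are returned separately in
# `scale_type`.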
def quantize(x, config):
x_shape = x.shape
contract_dim = x_shape[-1]
block_size = config.block_size
assert contract_dim >= block_size and contract_dim % block_size == 0
x_new_shape = x_shape[:-1] + (x_shape[-1] // block_size, block_size)
x = x.reshape(x_new_shape) # shape = (B, M, K / block_size, block_size)
amax = jnp.max(jnp.abs(x), axis=-1, keepdims=True)
MAX = jnp.finfo(config.data_type).max.astype(x.dtype)
scales = amax / MAX # shape = (B, M, K / block_size, 1)
if config.mode == "mxfp8":
assert config.scale_type == jnp.float8_e8m0fnu
scales_q = cast_to_e8m0_with_rounding_up(scales)
scaled_x = x / e8m0_to_dtype(scales_q, scales.dtype)
elif config.mode == "nvfp4":
assert config.scale_type == jnp.float8_e4m3fn
# TODO(shuw): Add nvfp4 support when XLA is ready and e2m1fn is available.
scales_q = scales
scaled_x = x
else:
raise ValueError(f"Unrecognized mode: {config.mode}.")
clipped_x = jnp.clip(scaled_x, -MAX, MAX)
x_q = clipped_x.astype(config.data_type)
x_q = x_q.reshape(x_shape) # shape = (B, M, K)
scales_q = jnp.reshape(scales_q, scales_q.shape[:-1]).view(
config.scale_type
)
return x_q, scales_q
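# For example, with the default mxfp8 config (block_size=32), quantizing an x
# of shape (2, 16, 64) returns x_q of shape (2, 16, 64) in float8_e4m3fn and
# scales of shape (2, 16, 2) in float8_e8m0fnu, one scale per 32-wide block.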
def scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
configs):
if preferred_element_type is None:
preferred_element_type = dtypes.result_type(
lhs, rhs, return_weak_type_flag=False
)
else:
preferred_element_type = dtypes.canonicalize_dtype(
np.dtype(preferred_element_type)
)
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
lhs_dn = (lhs_contract, lhs_batch)
rhs_dn = (rhs_contract, rhs_batch)
lhs_3d = shape_normalization(lhs, lhs_dn)
rhs_3d = shape_normalization(rhs, rhs_dn)
lhs_config, rhs_config = configs[0], configs[1]
lhs_q, lhs_scales = quantize(lhs_3d, lhs_config)
rhs_q, rhs_scales = quantize(rhs_3d, rhs_config)
out = scaled_matmul_wrapper(
lhs_q, rhs_q, lhs_scales, rhs_scales, preferred_element_type
)
expanded_out_shape = compute_dot_output_shape(
lhs.shape, rhs.shape, lhs_dn, rhs_dn
)
expanded_out = jnp.reshape(out, expanded_out_shape)
return expanded_out
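# The transpose rules below mirror those of lax.dot_general: the cotangent of
# lhs is a dot of the output cotangent g with rhs over the rhs dimensions that
# were not contracted in the forward pass, transposed back to lhs's layout
# (and symmetrically for rhs via scaled_dot_general_transpose_rhs). Both g and
# the saved operand are block-quantized first, so the backward matmuls also
# run through the scaled_matmul path.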
def scaled_dot_general_transpose_lhs(
g, x, y, *, dimension_numbers, preferred_element_type, configs,
swap_ans=False
):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
x_ndim = x.aval.ndim
x_kept = remaining(range(x_ndim), x_contract, x_batch)
y_kept = remaining(range(np.ndim(y)), y_contract, y_batch)
if swap_ans:
ans_batch, ans_y, _ = ranges_like(x_batch, y_kept, x_kept)
else:
ans_batch, _, ans_y = ranges_like(x_batch, x_kept, y_kept)
dims = ((ans_y, y_kept), (ans_batch, y_batch))
x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract)))
out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y)
y_dn = (y_kept, y_batch)
g_dn = (ans_y, ans_batch)
y_3d = shape_normalization(y, y_dn)
g_3d = shape_normalization(g, g_dn)
g_config, y_config = configs[0], configs[1]
g_q, g_scales = quantize(g_3d, g_config)
y_q, y_scales = quantize(y_3d, y_config)
out = scaled_matmul_wrapper(
g_q, y_q, g_scales, y_scales, preferred_element_type
)
expanded_out_shape = compute_dot_output_shape(g.shape, y.shape, g_dn, y_dn)
expanded_out = jnp.reshape(out, expanded_out_shape)
x_bar = lax.transpose(expanded_out, tuple(out_axes))
return x_bar
def scaled_dot_general_transpose_rhs(
g, x, y, *, dimension_numbers, preferred_element_type: DTypeLike,
configs: List[BlockScaleConfig]
):
(x_contract, y_contract), (x_batch, y_batch) = dimension_numbers
swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch))
y_bar = scaled_dot_general_transpose_lhs(
g,
y,
x,
dimension_numbers=swapped_dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs,
swap_ans=True,
)
return y_bar
@partial(custom_vjp, nondiff_argnums=(2, 3, 4))
def scaled_dot_general_fn(lhs, rhs, dimension_numbers, preferred_element_type,
configs):
return scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
configs)
def scaled_dot_fwd(lhs, rhs, dimension_numbers, preferred_element_type,
configs):
out = scaled_dot_impl(lhs, rhs, dimension_numbers, preferred_element_type,
configs)
res = (lhs, rhs)
return out, res
def scaled_dot_bwd(dimension_numbers, preferred_element_type, configs, res, g):
(lhs, rhs) = res
args = [g, lhs, rhs]
kw_args = {
"dimension_numbers": dimension_numbers,
"preferred_element_type": preferred_element_type,
}
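# configs is ordered as [lhs, rhs, gradient]: the incoming cotangent g is
# quantized with configs[2], while each saved forward operand reuses its own
# config.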
lhs_kw_args = {
**kw_args,
"configs": [configs[2], configs[1]]
}
rhs_kw_args = {
**kw_args,
"configs": [configs[2], configs[0]]
}
grad_lhs = scaled_dot_general_transpose_lhs(*args, **lhs_kw_args)
grad_rhs = scaled_dot_general_transpose_rhs(*args, **rhs_kw_args)
return (grad_lhs, grad_rhs)
scaled_dot_general_fn.defvjp(scaled_dot_fwd, scaled_dot_bwd)
def ensure_tuple(dimension_numbers):
_to_tuple = lambda x: x if isinstance(x, tuple) else tuple(x)
(lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers
lhs_contract = _to_tuple(lhs_contract)
rhs_contract = _to_tuple(rhs_contract)
lhs_batch = _to_tuple(lhs_batch)
rhs_batch = _to_tuple(rhs_batch)
return (lhs_contract, rhs_contract), (lhs_batch, rhs_batch)
def _ensure_batch_dim(lhs, rhs, dimension_numbers):
contracting_dims, (lhs_batch, rhs_batch) = dimension_numbers
lhs_batched = lhs
rhs_batched = rhs
if lhs_batch == (): # expand the last dim
lhs_batched = jnp.expand_dims(lhs, axis=lhs.aval.ndim)
lhs_batch = (lhs.aval.ndim,)
if rhs_batch == ():
rhs_batched = jnp.expand_dims(rhs, axis=rhs.aval.ndim)
rhs_batch = (rhs.aval.ndim,)
dn_batched = contracting_dims, (lhs_batch, rhs_batch)
return lhs_batched, rhs_batched, dn_batched
def scaled_dot_general_wrapper(
lhs, rhs, dimension_numbers,
preferred_element_type=jnp.float32,
configs: List[BlockScaleConfig] | None=None,
):
if preferred_element_type not in (jnp.float32, jnp.bfloat16, jnp.float16):
msg = ('Only f32, bf16, and f16 are supported for preferred_element_type, '
f'but got {preferred_element_type}')
raise TypeError(msg)
if configs is None:
mxfp8_config = BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
configs = [mxfp8_config, mxfp8_config, mxfp8_config]
dimension_numbers = ensure_tuple(dimension_numbers)
lhs_batched, rhs_batched, dn_batched = _ensure_batch_dim(
lhs, rhs, dimension_numbers
)
out = scaled_dot_general_fn(
lhs_batched, rhs_batched, dn_batched, preferred_element_type, configs,
)
# Expanding batch dims for operands adds a singleton batch dim at axis 0 in
# the output, which we need to squeeze.
if dn_batched != dimension_numbers:
return jnp.squeeze(out, axis=0)
return out


@ -21,7 +21,7 @@ from functools import partial
import operator
import math
import numpy as np
from typing import Any, Literal
from typing import Any, List, Literal
import jax
import jax.numpy as jnp
@ -36,10 +36,14 @@ from jax._src.core import AxisName
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P
from jax._src.cudnn.fused_attention_stablehlo import (
dot_product_attention as cudnn_dot_product_attention, MaskType)
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul_wrapper as cudnn_scaled_matmul,
scaled_dot_general_wrapper as cudnn_scaled_dot_general,
BlockScaleConfig)
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.numpy import util as numpy_util
from jax._src.typing import Array, ArrayLike, DType
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.ops.special import logsumexp as _logsumexp
@ -1185,3 +1189,109 @@ def dot_product_attention(
raise ValueError(f"Unsupported implementation option: {implementation}")
return jnp.reshape(out, output_shape)
def scaled_matmul(
lhs: Array,
rhs: Array,
lhs_scales: Array,
rhs_scales: Array,
preferred_element_type: DTypeLike = jnp.float32,
) -> Array:
r"""
Performs scaled matrix multiplication between two 3D arrays, with scaling
factors applied to the matrices.
.. math::
\mathrm{ScaledMatmul}(lhs, rhs, lhs\_scales, rhs\_scales) = lhs\_scales \cdot rhs\_scales \cdot \mathrm{dot}(lhs, rhs)
Args:
lhs (Array): A 3D array of shape (B, M, K).
rhs (Array): A 3D array of shape (B, N, K).
lhs_scales (Array): A 3D array of shape (B, M, K_block).
rhs_scales (Array): A 3D array of shape (B, N, K_block).
preferred_element_type (DTypeLike, optional): The preferred data type
for the computation. Defaults to `jnp.float32`.
Returns:
Array: A 3D array of shape (B, M, N) representing the scaled matrix
multiplication result.
Raises:
AssertionError: If the number of columns in `lhs` (`lhs_K`) does not
match the number of columns in `rhs` (`rhs_K`).
Notes:
- The function ensures that the `preferred_element_type` is
canonicalized before passing it to the underlying computation.
- Scaling is applied to the matrices based on the `lhs_scales` and
`rhs_scales` arrays, enabling efficient computations in blocks.
"""
B, M, lhs_K = lhs.shape
_, N, rhs_K = rhs.shape
assert lhs_K == rhs_K
_, _, K_block = lhs_scales.shape
preferred_element_type = dtypes.canonicalize_dtype(
np.dtype(preferred_element_type)
)
out = cudnn_scaled_matmul(
lhs,
rhs,
lhs_scales,
rhs_scales,
preferred_element_type=preferred_element_type,
)
return out
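# A minimal usage sketch (illustrative shapes; assumes the operands and their
# per-block scales come from an mxfp8-style quantizer and that a backend
# supporting block-scaled matmul is available):
#
#   import jax.numpy as jnp
#   from jax import nn
#
#   lhs = jnp.zeros((3, 128, 64), dtype=jnp.float8_e4m3fn)        # (B, M, K)
#   rhs = jnp.zeros((3, 256, 64), dtype=jnp.float8_e4m3fn)        # (B, N, K)
#   lhs_scales = jnp.ones((3, 128, 2), dtype=jnp.float8_e8m0fnu)  # (B, M, K/32)
#   rhs_scales = jnp.ones((3, 256, 2), dtype=jnp.float8_e8m0fnu)  # (B, N, K/32)
#   out = nn.scaled_matmul(lhs, rhs, lhs_scales, rhs_scales,
#                          preferred_element_type=jnp.float32)    # (B, M, N)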
def scaled_dot_general(
lhs, rhs,
dimension_numbers,
preferred_element_type=jnp.float32,
configs: List[BlockScaleConfig] | None = None,
implementation: Literal['cudnn'] | None = None,
):
r"""Scaled dot general operation.
Computes the scaled dot general on lhs and rhs, with quantization specified by configs:
.. math::
\widehat{lhs}, s_a=\mathrm{quantize}(lhs) \\
\widehat{rhs}, s_b=\mathrm{quantize}(rhs) \\
\mathrm{ScaledDot}(lhs, rhs)=s_a \cdot s_b \cdot \mathrm{dot}(\widehat{lhs}, \widehat{rhs})
Args:
lhs: Left-hand side input tensor.
rhs: Right-hand side input tensor.
dimension_numbers: A tuple specifying the contraction and batch dimensions
for the dot general operation. Must follow the format:
`((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims, rhs_batch_dims))`.
preferred_element_type: The preferred output data type. Supported types are
`jnp.float32`, `jnp.bfloat16`, and `jnp.float16`. Defaults to `jnp.float32`.
configs: A list of `BlockScaleConfig` specifying the scaling
configurations for the operation. Defaults to `mxfp8`.
implementation: A string to control which implementation backend to use.
Supported strings are `cudnn` (cuDNN block scaled dot). It defaults
to `None`, which will automatically select the best available backend.
Returns:
The result of the scaled dot general operation.
"""
# Create configs if not provided
if configs is None:
if dtypes.float8_e8m0fnu is None:
raise ValueError("Requires >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
mxfp8_config = BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
configs = [mxfp8_config for _ in range(3)]
if implementation is None:
implementation = 'cudnn'
match implementation:
case 'cudnn':
out = cudnn_scaled_dot_general(
lhs, rhs, dimension_numbers,
preferred_element_type=preferred_element_type,
configs=configs
)
case _:
raise ValueError(f"Unsupported implementation option: {implementation}")
return out
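# A minimal usage sketch (illustrative shapes; with configs=None this defaults
# to mxfp8 block scaling and the cuDNN backend, so it needs a GPU/cuDNN stack
# that supports block-scaled dot):
#
#   import jax
#   import jax.numpy as jnp
#   from jax import nn
#
#   k1, k2 = jax.random.split(jax.random.key(0))
#   a = jax.random.normal(k1, (2, 128, 96), dtype=jnp.bfloat16)
#   b = jax.random.normal(k2, (2, 160, 96), dtype=jnp.bfloat16)
#   dn = (([2], [2]), ([0], [0]))  # contract the last dim, batch over dim 0
#   out = nn.scaled_dot_general(a, b, dn,
#                               preferred_element_type=jnp.bfloat16)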


@ -206,6 +206,8 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
"dot_product_attention_fp8_bwd_wrapper",
):
continue
if p.name == "scaled_matmul_wrapper":
continue
if p.name in tf_not_yet_impl:
self.assertNotIn(
p, tf_impl) # Should not be in both tf_impl and tf_not_yet_impl


@ -37,6 +37,8 @@ from jax._src.nn.functions import (
relu as relu,
relu6 as relu6,
dot_product_attention as dot_product_attention,
scaled_dot_general as scaled_dot_general,
scaled_matmul as scaled_matmul,
selu as selu,
sigmoid as sigmoid,
soft_sign as soft_sign,


@ -1656,6 +1656,15 @@ jax_multiplatform_test(
tags = ["multiaccelerator"],
)
jax_multiplatform_test(
name = "scaled_matmul_stablehlo_test",
srcs = ["scaled_matmul_stablehlo_test.py"],
enable_backends = ["gpu"],
shard_count = {
"gpu": 4,
},
)
jax_py_test(
name = "custom_partitioning_sharding_rule_test",
srcs = ["custom_partitioning_sharding_rule_test.py"],


@ -27,8 +27,14 @@ import scipy.stats
from jax._src import ad_checkpoint
from jax._src import config
from jax._src import core
from jax._src import dtypes as _dtypes
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax._src.cudnn.scaled_matmul_stablehlo import (
quantize,
shape_normalization,
BlockScaleConfig,
)
from jax.test_util import check_grads
from jax import nn
from jax import random
@ -37,9 +43,9 @@ import jax.numpy as jnp
config.parse_flags_with_absl()
def _is_required_cudnn_version_satisfied(min_cudnn_version):
def _is_required_cudnn_version_satisfied(min_cc, min_cudnn_version):
return (
jtu.is_cuda_compute_capability_at_least("8.0") and
jtu.is_cuda_compute_capability_at_least(min_cc) and
cuda_versions is not None and
cuda_versions.cudnn_get_version() >= min_cudnn_version
)
@ -51,9 +57,158 @@ def _check_cudnn_backend(fn, *args, **kwargs):
_cudnn_dbias_error = 'cuDNN only supports bias gradient'
def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
# Explicitly cast the max values to the compute dtype to avoid unnecessary
# casting to FP32 during the subsequent math operations.
assert q_dtype in (jnp.float8_e4m3fn, )
dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
scaled_x = x / jnp.broadcast_to(
jnp.asarray(scale, dtype=compute_dtype), x.shape
)
clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
return clipped_x.astype(q_dtype)
def quantize_dequantize(x, q_dtype, scale, compute_dtype):
qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale)
out = qx.astype(x.dtype) * jnp.broadcast_to(
jnp.asarray(scale, dtype=x.dtype), qx.shape
)
return out
def _generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
configs, dtype=jnp.float32,
):
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=dtype,
)
k1, k2 = jax.random.split(jax.random.key(123), 2)
a = cast_to_representable(
jax.random.uniform(
k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(
k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[1].data_type,
)
dn = ((2,), (0,))
a_3d = shape_normalization(a, dn)
b_3d = shape_normalization(b, dn)
a_q, a_scales = quantize(a, configs[0])
b_q, b_scales = quantize(b, configs[1])
return a, b, a_q, b_q, a_scales, b_scales
def create_mxfp8_configs_if_available():
if _dtypes.float8_e8m0fnu is None:
raise unittest.SkipTest("float8_e8m0fnu is not available.")
def _create_mxfp8_config():
return BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
return [_create_mxfp8_config() for _ in range(3)]
@jtu.with_config(jax_legacy_prng_key="allow",
jax_numpy_dtype_promotion="standard")
class NNFunctionsTest(jtu.JaxTestCase):
@parameterized.product(
contract=[160, 96],
lhs_non_contract=[240, 100],
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
impl=['cudnn',],
)
def testScaledMatmul(self, contract, lhs_non_contract, dtype, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
# Check if float8_e8m0fnu is available
configs = create_mxfp8_configs_if_available()
batch, rhs_non_contract = 4, 256
a, b, a_q, b_q, a_scales, b_scales = _generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
configs, dtype=dtype,
)
out = nn.scaled_matmul(a_q, b_q, a_scales, b_scales,
preferred_element_type=dtype)
out_ref = jnp.matmul(a.astype(jnp.float32),
jnp.transpose(b, (0, 2, 1)).astype(jnp.float32))
self.assertArraysAllClose(
out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
)
@parameterized.product(
is_training=[True, False],
output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
impl=['cudnn',],
)
def testScaledDotGeneral(
self, is_training, output_type, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("10.0", 90700):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible")
configs = create_mxfp8_configs_if_available()
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=jnp.float32,
)
k1, k2 = jax.random.split(jax.random.key(0), 2)
a_shape = [2, 256, 96]
b_shape = [2, 96, 160]
dimension_numbers = (([2], [1]), ([0], [0]))
a = cast_to_representable(
jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
configs[1].data_type,
)
scaled_dot_general_fn = partial(
nn.scaled_dot_general, configs=configs
)
def fwd(a, b, is_ref=False):
fn = jax.lax.dot_general if is_ref else scaled_dot_general_fn
y = fn(a, b, dimension_numbers,
preferred_element_type=output_type)
return jnp.sum(y)
if is_training:
j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
j_train_ref = jax.jit(
jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
)
out, (x_grad, w_grad) = j_train(a, b)
out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
else:
j_inference = jax.jit(fwd)
j_inference_ref = jax.jit(partial(fwd, is_ref=True))
out = j_inference(a, b)
out_ref = j_inference_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
@parameterized.product(
dtype=[jnp.bfloat16, jnp.float16],
group_num=[1, 2, 4],
@ -61,7 +216,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
impl=['cudnn', 'xla'],
)
def testDotProductAttention(self, dtype, group_num, use_vmap, impl):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904):
if impl == 'cudnn' and not _is_required_cudnn_version_satisfied("8.0", 8904):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
if impl == 'cudnn' and dtype == jnp.float32:
raise unittest.SkipTest("cuDNN only supports fp16 or bf16.")
@ -110,7 +265,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
if isinstance(mask_mode, str):
mask_mode = (mask_mode,)
min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904
if not _is_required_cudnn_version_satisfied(min_cudnn_version):
if not _is_required_cudnn_version_satisfied("8.0", min_cudnn_version):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
dtype = jnp.bfloat16
@ -173,7 +328,7 @@ class NNFunctionsTest(jtu.JaxTestCase):
use_vmap=[False, True],
)
def testDotProductAttentionBiasGradient(self, batch_size, use_vmap):
if not _is_required_cudnn_version_satisfied(8904):
if not _is_required_cudnn_version_satisfied("8.0", 8904):
raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.")
dtype = jnp.bfloat16


@ -0,0 +1,536 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from absl.testing import absltest
import re
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import PartitionSpec, NamedSharding
from jax._src import config
from jax._src import dtypes as _dtypes
from jax._src import test_util as jtu
from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version
from jax._src.cudnn.scaled_matmul_stablehlo import (
scaled_matmul_wrapper,
scaled_dot_general_wrapper,
shape_normalization,
quantize,
BlockScaleConfig,
)
config.parse_flags_with_absl()
input_shardings = [
(("dp", None, "tp"), ("dp", None, "tp")),
(("dp", None, "tp"), ("dp", None, None)),
(("dp", None, "tp"), ("dp", "tp", None)),
(("dp", None, None), ("dp", "tp", None)),
(("dp", "tp", None), ("dp", "tp", None)),
((None, "dp", "tp"), (None, "dp", "tp")),
((None, "tp", None), (None, "tp", None)),
((None, None, "tp"), (None, "tp", None)),
]
c_name = "__cudnn$blockScaledDot"
expected_hlos = [
(c_name, "all-reduce", "f32[1,512,512]", "replica_groups={{0,1},{2,3}}"),
("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name),
("all-gather", "f8e4m3fn[1,512,512]", "replica_groups=[2,2]<=[4]", c_name),
(c_name,),
("all-gather", "f8e4m3fn[1,256,1024]", "replica_groups=[2,2]<=[4]", c_name),
(c_name, "reduce-scatter", "f32[2,256,512]", "replica_groups={{0,1},{2,3}}"),
("all-gather", "f8e4m3fn[2,512,1024]", "replica_groups=[2,2]<=[4]", c_name),
("all-gather", "f8e4m3fn[2,512,512]", "replica_groups=[2,2]<=[4]", c_name),
]
expected_output_spec = [
PartitionSpec('dp',),
PartitionSpec('dp',),
PartitionSpec('dp', None, 'tp'),
PartitionSpec('dp', None, 'tp'),
PartitionSpec('dp', 'tp', None),
PartitionSpec(None, 'dp', 'tp'),
PartitionSpec(None, 'tp', None),
PartitionSpec(None, None, 'tp'),
]
sharding_configs = {
input_sharding: (hlo, output_spec)
for input_sharding, hlo, output_spec in zip(input_shardings, expected_hlos, expected_output_spec)
}
def quantize_to_qtype(x, q_dtype, compute_dtype, scale):
# Explicitly cast the max values to the compute dtype to avoid unnecessary
# casting to FP32 during the subsequent math operations.
assert q_dtype in (jnp.float8_e4m3fn, )
dtype_max = jnp.finfo(q_dtype).max.astype(compute_dtype)
scaled_x = x / jnp.broadcast_to(
jnp.asarray(scale, dtype=compute_dtype), x.shape
)
clipped_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
return clipped_x.astype(q_dtype)
def quantize_dequantize(x, q_dtype, scale, compute_dtype):
qx = quantize_to_qtype(x, q_dtype, compute_dtype, scale)
out = qx.astype(x.dtype) * jnp.broadcast_to(
jnp.asarray(scale, dtype=x.dtype), qx.shape
)
return out
def generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
configs, dtype=jnp.float32,
):
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=dtype,
)
k1, k2 = jax.random.split(jax.random.key(123), 2)
a = cast_to_representable(
jax.random.uniform(
k1, (batch, lhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(
k2, (batch, rhs_non_contract, contract), minval=-1.0, dtype=dtype
),
configs[1].data_type,
)
dn = ((2,), (0,))
a_3d = shape_normalization(a, dn)
b_3d = shape_normalization(b, dn)
a_q, a_scales = quantize(a, configs[0])
b_q, b_scales = quantize(b, configs[1])
return a, b, a_q, b_q, a_scales, b_scales
def shard_and_device_put(
mesh, a_sharding, b_sharding, a, b, a_scales=None, b_scales=None
):
a_spec = PartitionSpec(*a_sharding)
b_spec = PartitionSpec(*b_sharding)
a_named_sharding = NamedSharding(mesh, a_spec)
b_named_sharding = NamedSharding(mesh, b_spec)
a = jax.device_put(a, a_named_sharding)
b = jax.device_put(b, b_named_sharding)
if a_scales is not None:
a_scales = jax.device_put(a_scales, a_named_sharding)
if b_scales is not None:
b_scales = jax.device_put(b_scales, b_named_sharding)
in_shardings = (
a_named_sharding,
b_named_sharding,
)
if a_scales is not None and b_scales is not None:
in_shardings = (
a_named_sharding,
b_named_sharding,
a_named_sharding,
b_named_sharding,
)
return a, b, a_scales, b_scales, in_shardings
return a, b, in_shardings
def create_mxfp8_configs():
if _dtypes.float8_e8m0fnu is None:
return None
mxfp8_config = BlockScaleConfig(
mode='mxfp8',
block_size=32,
data_type=jnp.float8_e4m3fn,
scale_type=jnp.float8_e8m0fnu,
global_scale=None,
infer_only=False
)
return [mxfp8_config for _ in range(3)]
def get_hlo_text(in_shardings, block_scale_configs):
mesh_names = ("dp", "tp")
devices = np.array(jax.local_devices()[:4]).reshape((2, 2))
mesh = Mesh(devices, mesh_names)
_, _, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
2, 512, 1024, 512, block_scale_configs,
)
with mesh:
a_q, b_q, a_scales, b_scales, in_shardings = shard_and_device_put(
mesh, in_shardings[0], in_shardings[1], a_q, b_q, a_scales, b_scales
)
pjit_fn = jax.jit(scaled_matmul_wrapper, in_shardings=in_shardings)
hlo = pjit_fn.lower(a_q, b_q, a_scales, b_scales).compile()
return hlo.as_text()
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class ScaledMatmulTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
self.skipTest(str(e))
return
if _dtypes.float8_e8m0fnu is None:
self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
if cudnn_version < 90700:
self.skipTest("Requires >= cuDNN 9.7.0")
if not jtu.is_cuda_compute_capability_at_least("10.0"):
self.skipTest("Requires at least Blackwell arch")
mxfp8_configs = create_mxfp8_configs()
@jtu.sample_product(
in_shardings=sharding_configs,
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
def test_collectives(self, in_shardings, block_scale_configs):
if jtu.device_under_test() != "gpu" or len(jax.local_devices()) < 4:
self.skipTest("Partition Test enabled for at least 4 GPUs")
expected_hlo = sharding_configs[in_shardings][0]
hlo_text = get_hlo_text(in_shardings, block_scale_configs)
hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in expected_hlo]), flags=re.DOTALL
)
self.assertRegex(
hlo_text, hlo_pattern, msg=f"Failed to find pattern: {expected_hlo}"
)
@jtu.sample_product(
contract=[160, 96],
lhs_non_contract=[240, 100],
dtype=[jnp.float16, jnp.bfloat16, jnp.float32],
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
def test_scaled_matmul(
self, contract, lhs_non_contract, dtype, block_scale_configs,
):
batch, rhs_non_contract = 2, 128
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
batch, lhs_non_contract, contract, rhs_non_contract,
block_scale_configs, dtype=dtype,
)
def wrapper(lhs, rhs, lhs_scales, rhs_scales, out_type):
return scaled_matmul_wrapper(
lhs,
rhs,
lhs_scales,
rhs_scales,
preferred_element_type=out_type,
)
j_scaled_matmul = jax.jit(partial(wrapper, out_type=dtype))
hlo_text = (
j_scaled_matmul.lower(a_q, b_q, a_scales, b_scales)
.compile()
.as_text()
)
hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
)
self.assertRegex(hlo_text, hlo_pattern)
out = j_scaled_matmul(a_q, b_q, a_scales, b_scales)
out_ref = np.einsum(
"BMK,BNK->BMN", a.astype(jnp.float32), b.astype(jnp.float32)
)
self.assertArraysAllClose(
out, out_ref.astype(dtype), rtol=1e-3, atol=1e-3
)
@jtu.sample_product(
in_shardings=sharding_configs,
block_scale_configs=[mxfp8_configs,],
)
@jtu.run_on_devices("cuda")
def test_scaled_matmul_sharded(self, in_shardings, block_scale_configs):
if len(jax.local_devices()) < 4:
self.skipTest("Require at least 4 devices to run sharding tests.")
batch, contract, non_contract = 2, 1024, 256
a, b, a_q, b_q, a_scales, b_scales = generate_quantized_tensors(
batch, non_contract, contract, non_contract, block_scale_configs,
)
devices = np.array(jax.local_devices()[:4])
devices = devices.reshape((2, 2))
expected_output_spec = sharding_configs[in_shardings][1]
with Mesh(devices, ("dp", "tp")) as mesh:
a_q, b_q, a_scales, b_scales, input_shardings = (
shard_and_device_put(
mesh,
in_shardings[0],
in_shardings[1],
a_q,
b_q,
a_scales,
b_scales,
)
)
args = [a_q, b_q, a_scales, b_scales]
j_scaled_matmul = jax.jit(
scaled_matmul_wrapper, in_shardings=input_shardings
)
hlo_compiled = j_scaled_matmul.lower(*args).compile()
hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
)
self.assertRegex(hlo_compiled.as_text(), hlo_pattern)
j_ref = jax.jit(
partial(
jax.lax.dot_general,
dimension_numbers=(([2], [2]), ([0], [0])),
),
in_shardings=input_shardings[:2],
)
out = j_scaled_matmul(*args)
out_ref = j_ref(a, b)
expected_output_sharding = NamedSharding(
mesh=mesh, spec=expected_output_spec
)
self.assertArraysAllClose(out, out_ref, rtol=1e-3, atol=1e-3)
self.assertTrue(
out.sharding.is_equivalent_to(expected_output_sharding, out.ndim)
)
@jtu.with_config(jax_numpy_dtype_promotion="standard")
class MxFp8ScaledDotGeneralTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
try:
cudnn_version = check_cudnn_version()
except RuntimeError as e:
self.skipTest(str(e))
return
if _dtypes.float8_e8m0fnu is None:
self.skipTest("Requries >= ml_dtypes 0.5.0 to support float8_e8m0fnu")
if cudnn_version < 90700:
self.skipTest("Requires >= cuDNN 9.7.0")
if not jtu.is_cuda_compute_capability_at_least("10.0"):
self.skipTest("Requires at least Blackwell arch")
block_scale_configs = create_mxfp8_configs()
@jtu.sample_product(
configs=[
# a_shape, b_shape, dimension_numbers, is_training
((1, 32), (2, 32), (([1], [1]), ([], [])), False),
((30, 64), (100, 64), (([1], [1]), ([], [])), False),
((192, 96), (160, 96), (([1], [1]), ([], [])), True),
((64, 128, 4), (128, 128), (([1], [0]), ([], [])), True),
((1, 128, 1024), (1, 1024, 128), (([2], [1]), ([0], [0])), True),
(
(1, 128, 128, 2),
(128, 1, 2, 128),
(([2], [0]), ([0, 3], [1, 2])),
True,
),
],
output_type=[jnp.float16, jnp.bfloat16, jnp.float32],
)
@jtu.run_on_devices("cuda")
def test_dot_general(self, configs, output_type):
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=jnp.float32,
)
k1, k2 = jax.random.split(jax.random.key(0), 2)
a_shape, b_shape, dimension_numbers, is_training = configs
a = cast_to_representable(
jax.random.uniform(k1, a_shape, minval=-1.0, dtype=output_type),
self.block_scale_configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(k2, b_shape, minval=-1.0, dtype=output_type),
self.block_scale_configs[1].data_type,
)
scaled_dot_general = partial(
scaled_dot_general_wrapper,
configs=self.block_scale_configs
)
def fwd(a, b, is_ref=False):
fn = jax.lax.dot_general if is_ref else scaled_dot_general
y = fn(a, b, dimension_numbers,
preferred_element_type=output_type)
return jnp.sum(y)
if is_training:
j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
j_train_ref = jax.jit(
jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
)
out, (x_grad, w_grad) = j_train(a, b)
out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
else:
j_inference = jax.jit(fwd)
j_inference_ref = jax.jit(partial(fwd, is_ref=True))
out = j_inference(a, b)
out_ref = j_inference_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
@jtu.sample_product(in_shardings=sharding_configs)
@jtu.run_on_devices("cuda")
def test_dot_general_sharded(self, in_shardings):
if len(jax.local_devices()) < 4:
self.skipTest("Require at least 4 devices to run sharding tests.")
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=jnp.float32,
)
dimension_numbers = (([2], [2]), ([0], [0]))
a_shape = (2, 128, 512)
b_shape = (2, 256, 512)
k1, k2 = jax.random.split(jax.random.key(0), 2)
a = cast_to_representable(
jax.random.uniform(k1, a_shape, minval=-1.0),
self.block_scale_configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(k2, b_shape, minval=-1.0),
self.block_scale_configs[1].data_type,
)
scaled_dot_general = partial(
scaled_dot_general_wrapper,
configs=self.block_scale_configs
)
def fwd(a, b, is_ref=False):
fn = jax.lax.dot_general if is_ref else scaled_dot_general
y = fn(a, b, dimension_numbers)
# Use a slightly more complex loss function to avoid constant grads,
# whose sharding info might be optimized away and then cause issues
# with the custom scaled_matmul op.
return jnp.sum(jnp.tanh(y))
devices = np.array(jax.local_devices()[:4])
devices = devices.reshape((2, 2))
with Mesh(devices, ("dp", "tp")) as mesh:
a, b, input_shardings = (
shard_and_device_put(
mesh,
in_shardings[0],
in_shardings[1],
a,
b,
)
)
j_train = jax.jit(jax.value_and_grad(partial(fwd), argnums=[0, 1]),
in_shardings=input_shardings)
hlo_text = j_train.lower(a, b).compile().as_text()
hlo_pattern = re.compile(
r".*".join([re.escape(x) for x in ("custom-call", c_name)])
)
j_train_ref = jax.jit(
jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1]),
in_shardings=input_shardings
)
out, (x_grad, w_grad) = j_train(a, b)
out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e-2)
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
@jtu.sample_product(
configs=[
((1, 128, 256), (1, 128, 256), (0, 0, 0)),
((2, 128, 128), (2, 128, 128), (0, 0, 0)),
((2, 128, 128), (128, 2, 128), (0, 1, 2)),
]
)
@jtu.run_on_devices("cuda")
def test_dot_general_vmap(self, configs):
cast_to_representable = partial(
quantize_dequantize,
scale=jnp.ones((1,)),
compute_dtype=jnp.float32,
)
k1, k2 = jax.random.split(jax.random.key(0), 2)
a_shape, b_shape, vmap_axes = configs
a_axis, b_axis, o_axis = vmap_axes
dimension_numbers = (([1], [1]), ([], []))
a = cast_to_representable(
jax.random.uniform(k1, a_shape, minval=-1.0),
self.block_scale_configs[0].data_type,
)
b = cast_to_representable(
jax.random.uniform(k2, b_shape, minval=-1.0),
self.block_scale_configs[1].data_type,
)
scaled_dot_general = partial(
scaled_dot_general_wrapper,
configs=self.block_scale_configs
)
def fwd(a, b, is_ref=False):
fn = jax.vmap(
jax.lax.dot_general if is_ref else scaled_dot_general,
in_axes=(a_axis, b_axis, None),
out_axes=o_axis,
)
y = fn(a, b, dimension_numbers)
return jnp.sum(y)
j_train = jax.jit(jax.value_and_grad(fwd, argnums=[0, 1]))
j_train_ref = jax.jit(
jax.value_and_grad(partial(fwd, is_ref=True), argnums=[0, 1])
)
out, (x_grad, w_grad) = j_train(a, b)
out_ref, (x_grad_ref, w_grad_ref) = j_train_ref(a, b)
self.assertArraysAllClose(out, out_ref, rtol=1e-2, atol=1e2)
self.assertArraysAllClose(x_grad, x_grad_ref, rtol=1e-2, atol=1e1)
self.assertArraysAllClose(w_grad, w_grad_ref, rtol=1e-2, atol=1e1)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())